【图解】由浅入深理解机器学习和GPT原理

共 4050字,需浏览 9分钟

 ·

2023-06-20 11:24

原作者:@JayAlammar 翻译:成江东
目前看到的最通俗易懂、由浅入深的图解机器学习和GPT原理的系列文章,这是第一篇,由我和 GPT-4共同翻译完成,分享给大家。
从一个简单的例子开始。假设你正在帮助一个想买房子的朋友。她被报价40万美元购买一个2000平方英尺(185平方米)的房子。这个价格合适吗?在没有参照物的情况下,这很难判断。所以你询问了在同一个社区购买过房子的朋友们,最后得到了三个数据点:


就我个人而言,我的第一反应是计算每平方英尺的平均价格。这个价格是每平方英尺180美元。

欢迎来到你的第一个神经网络!虽然它还没有达到Siri的水平,但现在你已经了解了基本的构建模块。它看起来是这样的:

这样的图表展示了网络的结构以及如何计算预测。计算从左侧的输入节点开始。输入值向右流动。它乘以权重,结果就成为我们的输出。将2,000平方英尺乘以180,我们得到360,000美元。在这个层面上,计算预测就是简单的乘法。但在此之前,我们需要考虑我们将要乘以的权重。这里我们从平均值开始,稍后我们将研究更好的算法,以便在获得更多输入和更复杂模型时进行扩展。找到权重就是我们的“训练”阶段。所以,每当你听到有人在“训练”神经网络时,它只是指找到我们用来计算预测的权重。

这是一个简单的预测模型,它接受输入,进行计算,并给出输出(由于输出可以是连续值,我们所拥有的技术名称是“回归模型”)

注:回归模型是一种用于预测因果关系的统计模型,它通常用于研究与某些因素有关的连续变量。它基于已知数据的线性或非线性方程,通过最小化误差或损失函数来拟合数据,并通过该方程对未知数据进行预测。回归模型可以用于分析多种因素对某一变量的影响,例如在经济学、社会学、医学、工程学等领域中,它经常被用于探索因果关系和预测未来趋势。常见的回归模型包括线性回归、多项式回归、逻辑回归等。

将这个过程可视化(为了简化,将价格单位从1美元换成1000美元。现在我们的权重是0.180而不是180):

更难、更好、更快、更强

我们能否在估计价格方面做得比基于数据点平均值更好呢?首先,定义在这种情况下更好的意义。如果我们将模型应用于我们拥有的三个数据点,它会做得多好?

如图所示,黄线是误差值,黄线长是不好的,我们希望尽可能减小黄线的长度。

平均值:2,058

在这里,我们可以看到实际价格值、预测价格值以及它们之间的差异。然后我们需要对这些差异求平均,以便得到一个表示预测模型中有多少错误的数字。问题是,第3行的值为-63。如果我们想用预测值和价格之间的差异作为衡量误差的标准,我们必须处理这个负值。这就是为什么我们引入了一个额外的列,显示误差的平方,从而消除了负值。这就是我们定义更好模型的标准 - 更好的模型是误差较小的模型。误差是数据集中每个点误差的平均值。对于每个点,误差是实际值和预测值之间的差异的平方。这称为均方误差。将其作为指导来训练我们的模型使其成为我们的损失函数(也称为成本函数)。

现在我们已经定义了衡量更好模型的标准,尝试一些其它权重值,并将它们与我们的平均值进行比较:

通过改变权重,我们无法在模型上做出太多改进。但是,如果我们添加一个偏置值,我们可以找到改进模型的值。现在我们添加了这个b值到线性公式中,我们的预测值可以更好地逼近我们的实际值。在这种情境下,我们称之为“偏置”。这使得我们的神经网络看起来像这样:

我们可以概括地说,一个具有一个输入和一个输出的神经网络(剧透警告:没有隐藏层)看起来像这样:

在这个图中,W 和 b 是我们在训练过程中找到的值,X 是我们输入到公式中的值(例如,我们的示例中的房屋面积(平方英尺))。Y 是预测的价格。现在,计算预测使用这个公式:

因此,我们当前的模型通过将房屋面积作为 x 插入,使用这个公式来计算预测:

训练你想尝试训练我们的玩具神经网络吗?通过调整权重和偏置来最小化损失函数。你能让误差值低于799吗?

自动化恭喜你手动训练了你的第一个神经网络!看看如何自动化这个训练过程。下面是另一个带有自动驾驶功能的示例。这些是 GD Step 按钮。它们使用一种称为“梯度下降”的算法,尝试向正确的权重和偏置值迈进,以最小化损失函数。

这两个新图表可以帮助你在调整模型参数(权重和偏置)时跟踪误差值。跟踪误差非常重要,因为训练过程就是尽可能减少这个误差。梯度下降如何知道它的下一步应该在哪里?可以利用微积分。你看,我们知道我们要最小化的函数(损失函数,所有数据点的(y_ - y)²的平均值),也知道当前输入的值(当前的权重和偏置),损失函数的导数告诉我们应该如何调整 W 和 b 以最小化误差。想了解更多关于梯度下降以及如何使用它来计算新的权重和偏置的信息,请观看 Coursera 机器学习课程的第一讲。

引入第二变量

房子的大小是决定房价的唯一变量吗?显然还有很多其他因素。添加另一个变量,看看我们如何调整神经网络来适应它。假设你的朋友做了更多的研究,找到了更多的数据点。她还发现了每个房子有多少个浴室:


我们的两变量神经网络如下所示:

现在我们需要找到两个权重(每个输入一个)和一个偏置来创建我们的新模型。计算Y的公式如下:

但是我们如何找到w1和w2呢?这比我们只需要考虑一个权重值时要复杂一些。多一个浴室对我们预测房价的影响有多大呢?尝试找到合适的权重和偏置。从这里开始,你会看到随着输入数量的增加,我们面临的复杂性也在增加。我们开始失去创建简单二维形状的能力,这使得我们不能一眼就能看出模型的特点。相反,我们主要依赖于在调整模型参数时,误差值是如何变化的。

我们再次依靠可靠的梯度下降法来帮助我们找到合适的权重和偏置。

特征

现在你已经了解了具有一个和两个特征的神经网络,你可以尝试添加更多特征并使用它们来计算预测值。权重的数量将继续增长,当我们添加每个新特征时,我们需要调整梯度下降的实现,以便它能够更新与新特征相关的新权重。这里需要注意的是,我们不能盲目地将我们所知道的所有信息都输入到网络中。我们必须在输入模型的特征上有所选择。特征选择/处理是一个拥有自己一套最佳实践和注意事项的独立学科。如果你想看一个关于检查数据集以选择输入预测模型的特征的过程的例子,请关注公众号:「数据STUDIO」。这是一个学习机器学习算法的好地方!里面包含了很多相关项目。Omar EL Gabry在其中讲述了他解决Kaggle泰坦尼克挑战的过程。Kaggle提供了泰坦尼克号上乘客的名单,包括姓名、性别、年龄、船舱以及该人是否幸存等数据。挑战的目标是建立一个模型,根据其他信息预测一个人是否幸存。

分类

继续调整我们的例子。假设你的朋友给你一份房子清单。这次,她标注了哪些房子在她看来具有合适的大小和浴室数量:


她需要你使用这个方法来创建一个模型,根据房子的大小和浴室数量来预测她是否会喜欢这个房子。你将使用上面的列表来构建模型,然后她将使用这个模型来对许多其他房子进行分类。在这个过程中还有一个额外的改变,那就是她还有另一个包含10个房子的列表,她已经对这些房子进行了标记,但她没有告诉你。这个另外的列表将在你训练模型后用来评估你的模型,从而确保你的模型能够把握她实际喜欢的房子特征。我们迄今为止所尝试的神经网络都是进行“回归”操作的,它们计算并输出一个“连续”的值(输出可以是4,或100.6,或2143.342343)。然而,在实践中,神经网络更常用于“分类”类型的问题。在这些问题中,神经网络的输出必须是一组离散值(或“类别”),如“好”或“坏”。实践中的工作原理是,我们将会得到一个模型,该模型会表明某个房屋是“好”的可能性为75%,而不仅是简单地输出“好”或“坏”。

在实践中,我们可以将我们已经看到的网络转换成一个分类网络,让它输出两个值——一个值代表某个个类别(我们现在的类别是“好”和“坏”)。然后我们将这些值通过一个叫做“softmax”的操作。softmax的输出是每个类别的概率。例如,假设网络的这一层输出“好”为2,“坏”为4,如果我们将[2, 4]输入到softmax中,它将返回[0.11, 0.88]作为输出。这意味着网络有88%的把握认为输入的值是“坏”的,我们的朋友可能不喜欢那个房子。

Softmax函数接受一个数组作为输入,并输出一个相同长度的数组。注意到它的输出都是正数,并且总和为1,这在输出概率值时非常有用。另外,尽管4是2的两倍,但它的概率不仅是2的两倍,而且是2的八倍。这是一个有用的特性,它可以夸大输出之间的差异,从而改善我们的训练过程。

如您在最后两行中所看到的,softmax可以扩展到任意数量的输入。所以现在如果我们的朋友添加了第三个标签(比如说“不错,但我得把一间房子租给airbnb”),softmax可以扩展以适应这种变化。花点时间探索一下网络的形状,看看当您改变特征数量(x1、x2、x3等)(可以是面积、浴室数量、价格、靠近学校/工作的距离等)和类别数量(y1、y2、y3等)(可以是“太贵了”、“性价比高”、“如果我把一间房子租给airbnb就好了”、“太小了”)时,网络是如何变化的。

您可以在我为本文创建的这个笔记本中看到如何使用 TensorFlow 创建和训练这个网络的示例。真正的动力如果您已经读到这里了,我必须向您揭示我写这篇文章的另一个动力。这篇文章旨在作为一个更加温和的 TensorFlow 教程入门。如果您现在开始学习《MNIST 机器学习初学者》,并遇到了这张图:

原文地址:https://jalammar.github.io/visual-interactive-guide-basics-neural-networks/

基于ChatGPT,论文写作工具

国内可用 ChatGPT 客户端下载

数据分析入门:统计学基础知识总结

可能是全网最全的速查表:Python Numpy Pandas Matplotlib 机器学习 ChatGPT


浏览 12
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报