小宇:閃客閃客,現(xiàn)在的 AI 好神奇呀,你能給我講講它的原理嗎?閃客:你個菜雞,連最基本的機器學習是什么都不知道,就妄想一下子了解現(xiàn)在 AI 原理?
小宇:額,注意你的態(tài)度!那你說怎么辦嘛!
閃客:現(xiàn)在你先忘掉 AI,忘掉所有的什么 ChatGPT、大模型、深度學習、機器學習、神經(jīng)網(wǎng)絡這些概念。
小宇:哦好,雖然我本來就沒聽說過這些。線性回歸
閃客:啊這... 好吧,我們來一個場景。我想研究雞的數(shù)量和腿的數(shù)量的關系,于是我列了一個表格。
雞 腿
5 10
7 14
8 16
9 18
那我問你,假如雞的數(shù)量是 10,那么腿的數(shù)量是多少?小宇:額,你是不是把我當傻帽呀,我不看你這表也知道,腿的數(shù)量就是雞數(shù)量的 2 倍嘛,當然是 20 了!閃客:沒錯,你直接找到了雞和腿數(shù)量之間的規(guī)律,是嚴格符合 y = 2x 的函數(shù)關系。假如世界上所有的事情都能找到其對應的嚴格的函數(shù)關系,那該多好,這就是早期機器學習符號主義的愿景。
畫外音:機器學習的符號主義(Symbolic AI 或 Symbolic Machine Learning) 是人工智能(AI)領域的一種方法,主要基于符號和規(guī)則來表示知識和推理。這種方法與現(xiàn)代機器學習方法(例如深度學習)形成了鮮明對比,后者依賴于神經(jīng)網(wǎng)絡和大量數(shù)據(jù)的模式識別。符號主義在20世紀70-90年代被廣泛應用,是人工智能早期的主要研究方向之一。小宇:誒?這看起來非?茖W嚴謹呀,為什么這樣做不行呢?
閃客:如果能實現(xiàn)這個愿景固然是好的,但人們還是低估了這個世界的復雜程度。想想看,如果讓你用一個函數(shù)來預測股票是漲還是跌,這可能嗎?
小宇:總感覺理論上是可行的,但實際上應該做不到,不然我也不會在這學什么機器學習了哈哈。
閃客:是的,這種看似能夠找到規(guī)律的事情都做不到,更別提人類智慧這種的復雜問題了。
小宇:哎,那這可怎么辦呢?
閃客:別急,咱先別考慮那么遠的問題,我先給你出一個比剛剛數(shù)雞腿更復雜點的問題,你找找看下面 X 和 Y 的關系。
X Y1 2.6
2 3.0
3 3.7
4 4.5
5 4.4
6 4.9
7 6.0
8 6.2
9 6.4
10 7.2
小宇:額,總感覺有點規(guī)律,但又不能一下子看出來,有點燒腦。
閃客:確實,不過我們把這些點畫在坐標軸上,你再看看呢。
小宇:哇!這么看清晰了好多,不過還是不能一下子看出來什么。
閃客:那我再加一條線呢?
小宇:哎呀!這感覺已經(jīng)找到規(guī)律了,大概就是 y = 0.5x + 2 嘛!閃客:沒錯,你居然直接把函數(shù)說出來了。
小宇:你圖都畫成這樣了,我還說不出函數(shù),那就太不應該了。不過我猜到你接下來要說什么了,就是如何找到這個函數(shù)對吧?
閃客:沒錯,直覺上,我們是想讓這條線盡可能靠近所有點,但怎么用數(shù)學或計算機語言表達"靠得近",就是個問題了。
小宇:emmm,好像不太容易想到,沒想到這么簡單直觀的問題,要是用嚴肅的數(shù)學語言描述,還挺難的。
閃客:是的,我給你加幾條線,你看看有沒有啟發(fā)。
小宇:啊!我明白了,可以用每個點到這條線的偏離距離的總和,來表示點與線的“貼合程度”,這個數(shù)越小越好。
閃客:沒錯!所以我們就可以定義如下的損失函數(shù),來表示這條線和這些點的偏離程度,只要找到這個函數(shù)的最小值即可!
小宇:額,你這太不絲滑了呀,前面還一個公式都沒有,怎么突然冒出來這么個東西。
損失函數(shù)
閃客:哈哈,本來想給你嚇回去的,但既然你沒走,那我們就專門來聊聊這個"損失函數(shù)"到底是個啥東西,為什么叫它"損失"。
小宇:是因為算出來的數(shù)特別讓人"損失信心"嗎?
閃客:哈哈哈,這個腦洞不錯,但其實它的損失更像是"我們和完美結果之間的差距"。差距越大,損失就越大,差距越小,損失就越小。
小宇:哇,這個解釋好理解!
閃客:來,我們先從直觀的定義開始。假設某個點的真實值是 y,而我們的預測值是。你覺得兩者的誤差可以怎么表示?
小宇:很簡單呀,直接用 y不就行了?
閃客:不錯!這叫"誤差"或"偏差"。但問題來了,你覺得要是我們把所有點的誤差加起來,有啥問題?
小宇:嗯~正的誤差和負的誤差會互相抵消,最終看起來像沒什么偏差?
閃客:沒錯,像剛剛的那幾個 XY 的點,如果按這種算法來評估,就有可能找到一種驢唇不對馬嘴的預測,但它的損失卻是 0!
小宇:哈哈,確實離了大譜了,那可咋辦呢?
閃客:為了不讓誤差"藏著掖著",我們可以給它取個絕對值,這樣正負誤差都成了正的:
小宇:哦,這樣挺公平的呀。誒等等,這又有個新的數(shù)學公式,你得解釋解釋。
閃客:額,你是沒上過初中么?這個符號就是求和符號,表示把所有的 y - 的值都累加起來。比如把等差數(shù)列寫成求和符號的形式就是這樣。
小宇:哦懂了,這好像確實是初中就學過的,嘿嘿。
閃客:回過頭來看,這樣確實很公平,但有個小問題,就是絕對值有"尖點",數(shù)學優(yōu)化的時候不太友好,計算起來跟被卡在牙縫里一樣麻煩。
小宇:嗯確實,做題的時候其實最討厭碰到絕對值符號,還得分段討論,有一種情況沒想全就要扣分,最頭疼了。
閃客:所以我們更喜歡"平方誤差",就是把誤差平方后再加起來:
小宇:哇!這的確是個絕妙的辦法呀,平方之后,正負誤差都成正的!而且大的誤差更顯眼,就像班里成績特別差的同學會被老師特別關照一樣。
閃客:哈哈哈,沒錯!我們再平均一下,去掉樣本數(shù)量大小因素的影響,這就叫"均方誤差"(Mean Squared Error, MSE,看起來是不是又簡單又合理?
小宇:嗯,這次終于沒有突然甩出高大上的東西,我的信心回來了一點。
閃客:好了,找到了損失函數(shù),還記得我們要干啥不?
小宇:記得,讓損失函數(shù)最!
閃客:不錯,這時候我們得把 表示出來,我們可以假設預測的直線的方程是 y = wx + b,像下面這樣。
不過我們可以先簡化一點,認為這條直線穿過原點,這樣就可以少個 b。
這個時候帶入 MSE 中,就是
我們想要計算的就是,w 為多少的時候,這個損失函數(shù)的值最小。
小宇:完了完了,我已經(jīng)頭疼了,這里咋這么多字母,我已經(jīng)暈了。
閃客:別急,這些字母里其實只有 w 是未知的,其他的都是已知數(shù)。我們舉個簡單的例子就明白了了。我們先不看上面那個復雜的例子,假設 x=[1,2,3,4] y=[1,2,3,4] 這樣傻子都能看出來規(guī)律對吧,我們就用這個來舉例。
小宇:哈哈這個簡單,不用算也知道就是 y = x
閃客:沒錯,我們就用這個算一下,把這里的 x 和 y 的值都代入到剛剛的損失函數(shù)中。
接下來就是一個標準的求函數(shù) L 的極小值點的過程,這種苦力活我怎么可能自己做呢,交給 AI 吧。
小宇:哈哈,你可真懶,不過這過程解釋得真細致呀,要讓你講肯定不能這么有耐心。我再補個圖吧,剛剛 w = 1 就表示預測直線的方程是 y = x,就像這樣,確實損失最小呢!
閃客:沒錯,實際上剛剛的
畫成圖就是個拋物線,尋找最小值點就是尋找拋物線的最低點。
小宇:原來如此!誒?那如果回到最初,我們不簡化預測函數(shù)的直線方程,直接是 y = wx + b 呢?這要怎么辦?
閃客:一樣的,這樣最終代入到損失函數(shù)后,就是關于 w 和 b 兩個未知變量的函數(shù),求極值點如果畫成圖的話,就不再是拋物線了,而是三維坐標中的曲面。
這時候就得用偏導數(shù)來計算了,具體太數(shù)學了就不展開了,偏導數(shù)我做了兩個動圖,你可以感悟一下。
小宇:哎呀雖然這動畫很絲滑,但想起來是真燒腦呀,更何況這還是最簡單的形狀了,如果七扭八歪或者維度更高就...
閃客:是的,所以這個時候我們就不能直接硬求解了,得累死你,而且也利用不了計算機的優(yōu)勢。這時候我們可以用另一種更適合計算機一步一步逼近答案的求解方法 -- 梯度下降。
梯度下降
小宇:啊,這么神奇!那快告訴我什么是梯度下降呢?
閃客:別急,直接告訴你可不是我的風格,我們先不要管什么梯度下不下降的,先來想想我們的目的是什么。
小宇:嗯目的我還是清晰的,就是我們想求解一個叫損失函數(shù)的最小值,比如 L(w, b) 甚至更多維的 L(w, w, ..., b)。
閃客:沒錯,但最終目標可不是知道這個最小值是多少。
小宇:哦哦對,我說得不給力,是求解使得這個損失函數(shù)最小的 w 和 b 都是多少。
閃客:沒錯,那你想想看,直接一步到位求出 w 和 b 的值太難了,那我們是不是可以一點一點調(diào)整它們,分多次求解呢?
小宇:一點一點調(diào)整?聽起來好像是個思路,但還是沒太明白怎么調(diào)整。
閃客:沒關系,我們假設個生活中的場景,你現(xiàn)在有一杯咖啡和糖,你怎么調(diào)出符合你口味的甜度呢?
小宇:哦這個我深有感悟,一步到位很難。比如我想要微微甜,那就得先加一點點糖,然后嘗一嘗,然后再加點,再嘗一嘗,直到剛好到我滿意為止。
閃客:沒錯,沒想到你還挺精致的,這就是梯度下降的精髓!
小宇:啊,這和梯度下降有什么關系呢?
閃客:你可以把符合你的口味這個目標當做一個損失函數(shù),糖的量就是損失函數(shù)中的參數(shù),你不能一下子就確定糖這個值是多少,于是只能從一個初始狀態(tài)開始,比如先加一勺糖,然后一點一點變化糖的量。每次加完糖后你品嘗咖啡就是你在計算這次的損失函數(shù),也就是你對口味的喜歡程度。
小宇:啊,我明白了!沒想到生活中的例子這么有啟發(fā)作用!
閃客:對!生活中的很多事都是這樣的,比如做飯調(diào)味、調(diào)音響音質,甚至選衣服搭配顏色,都是通過不斷嘗試和調(diào)整來找到最優(yōu)解。機器學習的梯度下降,也是用這種思路來優(yōu)化參數(shù)的。
小宇:這個思路我明白了。不過你之前說的“梯度”具體是啥呢?
閃客:假如損失函數(shù)只有一個參數(shù),像之前的 L(w),那么梯度就和導數(shù)是一個意思。
如果損失函數(shù)有多個參數(shù),像之前的 L(w,b),那么梯度就是各個參數(shù)的偏導數(shù)。
在這種情況下,梯度是個向量,是所有參數(shù)的偏導數(shù)累加起來的綜合結果。
小宇:額,你這一大堆輸出差點又給我整懵了,向量這個概念確實學過,但總感覺還不直觀,你能形象地給我展示下么?
閃客:沒問題,我們就拿之前三維坐標系下的那個帶兩個參數(shù) w 和 b 的損失函數(shù)來說,對應圖中的這個點,它的梯度是多少呢?
小宇:對 w 和 b 分別求偏導?
閃客:沒錯,在圖中,對 w 求偏導就是把 b = 0 這個平面和曲面的交線求導數(shù)。
把視角轉一下就清晰了。
小宇:原來如此!那對 b 求偏導呢?
閃客:也是一樣,線畫 w = -1 這個平面和曲面的交線。
從側面看,這條交線已經(jīng)在最低點了,所以 b 的偏導數(shù)就是 0。
所以把這兩股偏導數(shù)的力量合在一起,就是最終的向量,也就是梯度。
小宇:我明白了!其實就是找個坡度最大的方向往下滑,直到滑到最低點。
閃客:沒錯,不過這里的圖只是為了讓你形象理解梯度的意思,實際計算的時候不用考慮那么多,直接求各參數(shù)的偏導數(shù)就行了。
小宇:誒,那算出偏導數(shù)之后,要怎么樣呢?
閃客:簡單!每次都沿著梯度的反方向,走一小步,也就是你說的往下滑。公式寫出來是這樣的:
小宇:哇,這么簡單呀,其實就是每個參數(shù)每次都變化自己偏導數(shù)那么大的值就好了。
閃客:沒錯!不過這樣的話有個小問題,就是每次變化的這個量,太大了容易走過了錯過最低點,太小了又太磨嘰,所以我們乘以一個學習率 η 來調(diào)整一下速度。
小宇:哦還真是,人類真是好聰明呀!
閃客:哈哈是呀。咱們找到了梯度下降的求解方法,你來實踐一下吧;氐侥莻最簡單的題目,假設 x 和 y 的數(shù)據(jù)如下:x=[1,2,3,4] y=[1,2,3,4] ,求一下 y = wx 中的 w 是多少。
雖然傻子也能直接看出 y = x 是最終的解,不過我們就用這個來舉例實戰(zhàn)一下,你來用梯度下降的方法求一下 w 的值。
小宇:好的,不過我學你,這種小事兒我也懶得自己算了,交給 AI 吧!
閃客:哈哈真不賴,活學活用呀,這 AI 直接把圖都幫我們畫出來了,圖里可以看到損失函數(shù)的值 Loss 再逐漸降低為 0,而我們要計算的權重 w 的值在不斷接近 1。之后你看到再復雜的機器學習或者深度學習等過程的展示,最核心的其實就是這兩個東西的變化罷了。
小宇:哇,似乎有點 GET 到 AI 的核心邏輯了!我理解更高維度也就是更多參數(shù)的梯度下降求解,和這個步驟基本的思路是一致的。
閃客:沒錯,至于梯度下降的改進版本,比如動量法、Adam 優(yōu)化器等,以及更多計算模型,比如神經(jīng)網(wǎng)絡、卷積神經(jīng)網(wǎng)絡等,都是在這個核心思路的基礎上迭代出來的。
小宇:厲害了,這次講得還挺耐心!
閃客:哎呀,不知不覺又到飯點了,今天講的給你畫了這么多圖很累的,請我吃個飯吧。
小宇:哦才想起來我家里洗的衣服還在洗衣機里呢,我得回去晾衣服啦,下次吧。
閃客:哦~
來源:閃客
編輯:余蔭鎧