AIxiv專欄是機(jī)器之心發(fā)布學(xué)術(shù)、技術(shù)內(nèi)容的欄目。過去數(shù)年,機(jī)器之心AIxiv專欄接收報(bào)道了2000多篇內(nèi)容,覆蓋全球各大高校與企業(yè)的頂級實(shí)驗(yàn)室,有效促進(jìn)了學(xué)術(shù)交流與傳播。如果您有優(yōu)秀的工作想要分享,歡迎投稿或者聯(lián)系報(bào)道。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
近日,中國電信翼支付針對大模型推理加速的最新研究成果《Falcon: Faster and Parallel Inference of Large Language Models through Enhanced Semi-Autoregressive Drafting and Custom-Designed Decoding Tree》已被 AAAI 2025 接收。論文中提出的 Falcon 方法是一種增強(qiáng)半自回歸投機(jī)解碼框架,旨在增強(qiáng) draft model 的并行性和輸出質(zhì)量,以有效提升大模型的推理速度。Falcon 可以實(shí)現(xiàn)約 2.91-3.51 倍的加速比,在多種數(shù)據(jù)集上獲得了很好的結(jié)果,并已應(yīng)用到翼支付多個(gè)實(shí)際業(yè)務(wù)中。
論文地址:https://arxiv.org/pdf/2412.126391. 研究背景大型語言模型 (LLMs) 在各種基準(zhǔn)測試中展現(xiàn)了卓越的表現(xiàn),然而由于自回歸 (AR) 解碼方式,LLMs 在推理過程中也面臨著顯著的計(jì)算開銷和延遲瓶頸。為此,研究學(xué)者提出 Speculative Decoding (投機(jī)采樣) 方法。Speculative Decoding 會選擇一個(gè)比原始模型 (Target Model) 輕量的 LLM 作為 Draft Model,在 Draft 階段使用 Draft Model 連續(xù)生成若干個(gè)候選 Token。在 Verify 階段,將得到的候選 Token 序列放入到原始 LLM 做驗(yàn)證 & Next Token 生成,實(shí)現(xiàn)并行解碼。通過將計(jì)算資源導(dǎo)向于驗(yàn)證預(yù)先生成的 token,Speculative Decoding 大大減少了訪問 LLM 參數(shù)所需的內(nèi)存操作,從而提升了整體推理效率,F(xiàn)有的投機(jī)采樣主要采用兩種 Draft 策略:自回歸 (AR) 和半自回歸 (SAR) draft。AR draft 順序生成 token,每個(gè) token 依賴于前面的 token。這種順序依賴性限制了 draft 模型的并行性,導(dǎo)致顯著的時(shí)間開銷。相比之下,SAR draft 同時(shí)生成多個(gè) token,增強(qiáng)了 draft 過程的并行化。然而,SAR draft 的一個(gè)重要局限是它無法完全捕捉相同 block 內(nèi) draft tokens 之間的相互依賴關(guān)系,可能導(dǎo)致生成的 token 接受率較低。因此,在投機(jī)采樣中,平衡低 draft 延遲與高推測準(zhǔn)確性以加速 LLMs 的推理速度,是一個(gè)重大挑戰(zhàn)。為此,翼支付提出了 Falcon,一個(gè)增強(qiáng)的半自回歸(SAR)投機(jī)解碼框架,旨在增強(qiáng) draft model 的并行性和輸出質(zhì)量,從而提升 LLMs 的推理效率。Falcon 集成了 Coupled Sequential Glancing Distillation(CSGD)方法,提高了 SAR draft model 的 token 接受率。此外,F(xiàn)alcon還設(shè)計(jì)了一種專門的 decoding tree 來支持 SAR 采樣,使得 draft model 可以在一次前向傳播中生成多個(gè) token,并且也能夠支持多次前向傳播。這種設(shè)計(jì)有效提升 LLMs 對 token 的接受率,進(jìn)一步加快了推理速度。2.研究方法Falcon的架構(gòu)如圖 1 所示,可以看到,該半自回歸解碼框架主要由三個(gè)組件構(gòu)成:Embedding Layer、LM-Head和半自回歸解碼 Head。
圖 1 Falcon 框架圖具體來講,F(xiàn)alcon 將一個(gè)時(shí)間步長之前的連續(xù)特征序列和當(dāng)前 token 序列連接起來,以同時(shí)預(yù)測接下來的 k 個(gè)標(biāo)記。例如,當(dāng) k = 2 時(shí),F(xiàn)alcon 使用初始特征序列 (f1, f2) 和提前一個(gè)時(shí)間步長的標(biāo)記序列 (t2, t3) 來預(yù)測特征序列 (f3, f4)。隨后,將預(yù)測得到的特征 (f3, f4) 與下一個(gè)標(biāo)記序列 (t4, t5) 連接,形成新的輸入序列。這個(gè)新輸入序列用于預(yù)測后續(xù)的特征序列 (f5, f6) 和標(biāo)記序列 (t6, t7),從而促進(jìn) draft 過程的繼續(xù)。Draft model 多次 forward 之后生成的 token 被組織成樹結(jié)構(gòu),輸入到大模型中進(jìn)行 verify,通過 verify 的 token 被大模型接收,并基于此基礎(chǔ)開始下一個(gè)循環(huán)。2.1 Coupled Sequential Glancing Distillation當(dāng)前推測解碼方法的準(zhǔn)確性相對較低,主要原因是 token 之間的上下文信息不足。CSGD 通過用真實(shí) token 和 hidden states 替換一些初始預(yù)測來改善這一點(diǎn),將正確信息重新注入解碼過程中,從而提高后續(xù)預(yù)測的準(zhǔn)確性和連貫性。模型結(jié)構(gòu)及訓(xùn)練流程如下圖:
圖 2 CGSD 方法示意圖在訓(xùn)練過程中,一個(gè)時(shí)間步長之前的連續(xù)特征序列和當(dāng)前 token 序列連接起來,并輸入到 draft model 中,形成一個(gè)融合序列,其維度為 (bs, seq_len, 2 * hidden_dim)。draft model 由一個(gè)混合 Transformer 網(wǎng)絡(luò)組成,該網(wǎng)絡(luò)包括兩層 LSTM、Relaxed Causal-Masked 多頭注意力機(jī)制,以及 MLP 網(wǎng)絡(luò)。其中 LSTM 網(wǎng)絡(luò)將融合序列的維度減少到 (bs, seq_len, hidden_dim),并保留關(guān)于過去 token 的信息,從而提高模型的準(zhǔn)確性。Relaxed Causal-Masked 多頭注意力機(jī)制能夠在保持因果關(guān)系的同時(shí),專注于輸入序列的相關(guān)部分。MLP 層進(jìn)一步處理這些信息,以做出最終預(yù)測。當(dāng)序列首次通過 draft model 后,會生成初始的 token 預(yù)測
。然后,我們計(jì)算 draft model 的預(yù)測與真實(shí) token Y 之間的漢明距離,以此來衡量預(yù)測的準(zhǔn)確性。接下來,我們將一定數(shù)量連續(xù)預(yù)測的 token 序列
和特征序列
替換為來自 LLMs 的正確 token 序列
和特征序列
。CSGD 與傳統(tǒng)的 glancing 方法不同,后者僅隨機(jī)替換 token。相反,CSGD 選擇性地同時(shí)替換預(yù)測之前的連續(xù) token 和特征序列,如圖 2 中虛線框標(biāo)注的 choice 1、choice 2、choice3 所示。這種方法增強(qiáng)了對 token 間的關(guān)系的理解,并確保 draft model 能夠有效利用提前時(shí)間步長的 token 序列,這在 SAR 解碼中尤為重要。隨后,修正后的 token 和特征序列被重新輸入到 draft model 中以計(jì)算訓(xùn)練損失。在訓(xùn)練過程中,我們采用了知識蒸餾,損失函數(shù)包括 draft model 的輸出特征與真實(shí)特征之間的回歸損失以及蒸餾損失,具體的損失函數(shù)如下:
2.2 Custom-Designed Decoding Tree當(dāng)前基于樹的推測解碼方法通過在每個(gè)起草步驟生成多個(gè) draft token 來提升推測效率。然而,這些方法仍然需要 draft model 按順序生成 token,這限制了推測效率的進(jìn)一步提高。為了解決這一局限性,CDT (Custom-Designed Decoding Tree) 支持 draft model 在一次前向傳遞中生成多個(gè) token (k 個(gè)),并且在每個(gè) draft 步驟中支持多次前向傳遞。因此,與現(xiàn)有方法相比,CDT 生成的草稿標(biāo)記數(shù)量是其 k 倍。Draft model 多次 forward 之后,生成的 token 被組織成樹結(jié)構(gòu),輸入到大模型中進(jìn)行 verify。LLM 使用基于樹的并行解碼機(jī)制來驗(yàn)證候選 token 序列的正確性,被接受的 token 及其相應(yīng)的特征序列會在后續(xù)繼續(xù)進(jìn)行前向傳遞。在傳統(tǒng)的自回歸(AR)解碼中,使用因果掩碼,其結(jié)構(gòu)為下三角矩陣。它確保了前面的 token 不能訪問后面的信息。相比之下,F(xiàn)alcon 采用了一種 causal 因果掩碼 (如圖 3 所示),允許模型訪問同一 k*k 的 block 內(nèi)的 token 以及相應(yīng)的之前的連續(xù) token。這一增強(qiáng)顯著提高了 drafter 生成 token 的效率,使 LLM 能夠同時(shí)驗(yàn)證更多的 token,從而加快了 LLM 的整體推理速度。
圖 3 Custom-Designed Decoding Tree 方法示意圖3.實(shí)驗(yàn)結(jié)果我們在多個(gè)數(shù)據(jù)集和多個(gè)模型上進(jìn)行了廣泛的實(shí)驗(yàn),驗(yàn)證了本文方法的有效性。和現(xiàn)有的方法相比,F(xiàn)alcon 展現(xiàn)了優(yōu)越的性能,具體如下圖:
圖 4 Falcon 實(shí)驗(yàn)結(jié)果圖4.業(yè)務(wù)潛力Falcon 大模型可以實(shí)現(xiàn)約 2.91-3.51 倍的加速比,相當(dāng)于同等條件下推理成本下降至約原先的 1/3,從而大幅降低了大模型推理計(jì)算相關(guān)成本。當(dāng)前,F(xiàn)alcon 技術(shù)已轉(zhuǎn)化至翼支付大模型產(chǎn)品 InsightAI 平臺,并已服務(wù)諸如翼支付數(shù)字人客服、借錢-翼小橙、人力-翼點(diǎn)通、財(cái)務(wù)-翼小財(cái)?shù)榷鄠(gè)業(yè)務(wù)應(yīng)用。5.總結(jié)投機(jī)采樣是大模型推理加速的一個(gè)核心方法。當(dāng)前,主要的挑戰(zhàn)是如何提升 draft model 的準(zhǔn)確率、采樣效率,并提升大模型的驗(yàn)證效率。文章提出了 Falcon 方法,一種基于增強(qiáng)半自回歸投機(jī)解碼框架。Falcon 通過 CSGD 這種訓(xùn)練方法以及半自回歸的模型設(shè)計(jì),顯著提升了 draft model 的預(yù)測準(zhǔn)確率以及采樣效率。此外,為了讓大模型能驗(yàn)證更多的 token,本文精心設(shè)計(jì)了一個(gè) decoding tree,有效提升了 draft model 的效率,從而提升了驗(yàn)證效率。Falcon 在多種數(shù)據(jù)集上可以實(shí)現(xiàn)約 2.91-3.51x 的加速比并應(yīng)用到翼支付的眾多業(yè)務(wù)中,獲得了很好的效果。6. 公司簡介天翼電子商務(wù)有限公司(翼支付)是中國電信集團(tuán)有限公司成員企業(yè)。公司堅(jiān)持以科技創(chuàng)新引領(lǐng)戰(zhàn)略新興業(yè)務(wù)發(fā)展,運(yùn)用大數(shù)據(jù)、人工智能、區(qū)塊鏈等關(guān)鍵技術(shù),積極探索支付 + AI 服務(wù)體驗(yàn),將 AI 應(yīng)用于支付服務(wù)、智能客服、風(fēng)險(xiǎn)管理等多個(gè)方面,以更好推動產(chǎn)業(yè)數(shù)字化轉(zhuǎn)型升級,助力數(shù)字中國建設(shè)發(fā)展。THE END轉(zhuǎn)載請聯(lián)系本公眾號獲得授權(quán)