在 CFU Playground 上加速 MLPerf™ Tiny 影像分類模型 #4:脈動陣列與矩陣乘法單元

Yeecy
54 min readJan 8, 2024

--

4.1 脈動陣列

在一眾專門針對深度學習模型訓練和推理的硬體裡,除去全球最大 AI 軍火商老黃手下的 NVIDIA 計算卡外,普羅大眾能夠有所耳聞的大概就只剩 Google 設計並佈署在 Google Cloud 的 TPU。Google 在自家的 TPU 裡使用了脈動陣列作為矩陣乘法單元,那什麼是脈動陣列呢?我們用個簡單的例子演示。

給定兩個定義域為整數的單變數函數 fg 且兩者值域相同,若該值域有加法與乘法之定義,則可以定義兩者的一維卷積為 (f * g)[x] = ∑ f[t] × g[xt],其中 t 的取值從負無窮到正無窮,如果讀者沒有學過的話不需緊張,只需知道其計算方法即可,另外因為 Medium 沒辦法渲染 LaTeX 公式,相關的數學式子非必要的話都會以 Unicode 呈現。

fg 都為數字 1 到 6 之間的離散均勻分布,意即兩者在 1 到 6 之間的整數對應到的值皆為 1/6,其他整數皆對應到到 0,假設今天有兩個公平的六面骰子,在機率學中我們知道擲出兩顆骰子後點數和為 n 的機率為 (f * g)[n],按卷積定義展開可以得到點數和為 n 的表達式如下。

(f * g)[n] = f[1] × g[n - 1]
+ f[2] × g[n - 2]
+ f[3] × g[n - 3]
+ f[4] × g[n - 4]
+ f[5] × g[n - 5]
+ f[6] × g[n - 6]

n 用 2 代入上式得到 (f * g)[2] = f[1]×g[1] + f[2]×g[0] + f[3]×g[-1] + f[4]×g[-2] + f[5]×g[-3] + f[6]×g[-4] = 1/36 + 0 + 0 + 0 + 0 + 0 = 1/36 ≈ 0.0278;同理將 n 用 7 代入可以得到 (f * g)[7] = f[1]×g[6] + f[2]×g[5] + f[3]×g[4] + f[4]×g[3] + f[5]×g[2] + f[6]×g[1] = 1/6 ≈ 0.1667,其餘的情況讀者可以自行計算。

以下是個簡單的程式,模擬骰一百萬次骰子的結果並統計出現機率,可以作為上面式子的旁證。

# Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
# Probability Simulation of the Sum of Two Fair Dice

from random import randint

trials = 1000000
count = [0 for i in range(13)] # count[0] and count[1] are not used

for trial in range(trials):
count[randint(1, 6) + randint(1, 6)] += 1

for dice_sum in range(2, 13):
print(f'P({dice_sum}) = {count[dice_sum] / trials}')

筆者執行程式後得到輸出如下,可以發現模擬結果跟卷積計算出的 (f * g)[2] 和 (f * g)[7] 結果相近。

P(2) = 0.027614
P(3) = 0.055674
P(4) = 0.083251
P(5) = 0.111688
P(6) = 0.138833
P(7) = 0.165672
P(8) = 0.139029
P(9) = 0.111056
P(10) = 0.083302
P(11) = 0.056016
P(12) = 0.027865

現在我們知道一維卷積的計算過程而且知道一維卷積確實有用,那麼現在是時候來看看怎麼把一維卷積的運算轉換為脈動陣列。

假設今天要設計一個硬體專門計算兩顆六面骰子的所有點數和的出現機率,且兩顆骰子的點數出現機率可能不相同,那麼該怎麼做呢?脈動陣列基本上可以分成兩個部分,一是處理單元,二是處理單元之間的接線,我們先來講一下大概的思路。

首先從前面的表達式中可以看到總共有六組 f[t] × g[xt] 的計算,那麼我們自然而然需要六個處理單元,接下來我們將處理單元串起來,一組連線的方向從左至右,另一組從右至左,如下圖所示。因為機率值為六個處理單元的和,我們需要用加法器算出最終的 (f * g)[n],不過之後為了簡潔起見,將不會畫出加法器。

因為 Medium 的程式碼區塊排版問題,部分示意圖會以內嵌 Gist 呈現

處理單元需要做的計算非常簡單,那就是單純把兩個輸入機率相乘並輸出,此處用 Python 描述行為如下,如果讀者有興趣可自行改寫為 Verilog 試試,上圖中處理單元左邊的 --> 在程式碼中以 left_in 表示,右邊的 <--right_in 表示。

# Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
# Computation of the Dice Sum Processing Element

def compute(n, left_in, right_in, output):
output = left_in * right_in

接下來我們需要決定資料要如何流動,根據上圖的設計,對於一個處理單元的 right_out 會是其右方處理單元的 left_in,同理可得該處理單元的 right_in 會是其左方處理單元的 left_out。處理單元每個週期會進行一次計算和一次資料搬移,當處理單元一多,感覺起來就像是一起脈動一樣,這即是脈動陣列名字的由來,在實作脈動陣列時,我們需要保證在計算完成前資料不會被新的值改寫。

# Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
# Data Movement of the Dice Sum Processing Element

def move_data(left_in, right_in, left_out, right_out, stall):
if not stall: left_out = right_in
right_out = left_in

stall 為一個布林值,之後會提到其用處。

現在的問題只剩下要如何決定 PE1 的 left_in 和 PE6 的 right_in 了。根據不同的策略,我們的資料流動方式會有所不同,基本的策略有兩種,第一種是 TPU 中所使用的權重靜止(weight stationary)策略,第二種是本文矩陣加速單元所用的輸出靜止(output stationary)策略,從實作複雜度來說,輸出靜止應該較為簡單,不過本例因為卷積的性質只能使用權重靜止。

我們首先將 f 的值推入處理單元裡,同時將 0 推入 g 的來源 PE1 的 left_in。當週期為 0 時,我們將 f[1] 會推入 PE6 的 right_in,確保在週期為 1 時 PE6 的 right_in 會是 f[1],同時在週期為 1 時,我們會將 f[2] 推入 PE6 的 right_in,依此類推,過程演示如下圖。

當週期為 6 時,我們將前面的 stall 設為 1,使得 f 的值停止流動,固定在處理單元裡,此時 f[1] 到 f[6] 各自散布在相應的處理單元,可以準備開始計算卷積。

與前面 f 的傳遞過程一樣,當週期為 7 時,我們將 g[1] 會推入 PE1 的 left_in,在週期為 8 時,將 g[2] 會推入 PE1 的 left_in,依此類推,過程演示如下,在這一過程中我們將能夠得到各個點數和的出現機率。

從過程推導中可以看到從第 0 個週期開始到第 18 個週期得到最後的 (f * g)[12] 為止,我們一共用了 19 個週期得到點數和為 2 到 12 的的出現機率,細心的讀者可能已經發現倘若在週期 6 時就將 g[1] 推入 PE1 的話,還能將所需週期減少至 18,如果再狠一點,我們可以額外拉線用一個週期把 f[1] 到 f[6] 放進相應的處理單元裡,那麼最後只需要 12 個週期即可。

看完這個範例,相信讀者應該能感覺到脈動陣列的設計精神了,脈動陣列只需簡單的硬體即可有效計算出答案,並且具有高度的可擴展性,假設今天需要能夠計算兩個二十面骰點數和機率的硬體,那麼我們只需把處理單元擴充到到二十個即可,不需要重新設計。

4.2 矩陣乘法的脈動陣列設計

在看過前面一維的範例後,我們來看如何用二維的脈動陣列來實現矩陣乘法。為了簡單起見,這裡會以 3×3 的脈動陣列作為範例,演示 3×3 的脈動陣列要如何計算矩陣乘法,雖說這個例子相當簡單,但我們可以很容易地把 3×3 的脈動陣列擴展成更大的脈動陣列。

為了方便閱讀,後續如果提到「mαnβkγ」的矩陣乘法,表示這個矩陣乘法的左矩陣和右矩陣大小分別為 α×γγ×β,且算出來的矩陣大小為 α×β

如前面所提,此處我們會用輸出靜止的形式實作脈動陣列,也就是說輸出矩陣的值會保存在處理單元內,而兩個要相乘的矩陣的值會源源不斷地傳入處理單元內。

從上一章節出現過的矩陣乘法定義 cₘₙ = aₘ₁×b + aₘ₂×b + … + aₘₖ×bₖₙ 可以知道 cₘₙA 中的第 m 列元素構成的向量與 B 中的第 n 行元素構成的向量的內積之值,所以對於處理單元來說,必須要能完成自身對應位置的向量內積。

我們現在來看用於矩陣乘法的處理單元長怎樣,因為處理單元要計算 A 的某一列,所以我們讓來自 A 的值從左至右水平流動,同理因為要計算 B 的某一行,所以讓 B 的值從上至下垂直流動,每個周期處理單元會計算 lefttop 的乘積,並將其累加儲存在 acc 裡;如同前面一維的例子,除了計算外,處理單元也會將值傳遞給鄰居,也就是將 lefttop 分別寫到 rightbottom 裡。

                 top
|
PE ⌄
+---------------+
| top: 0 |
left ---> | left: 0 | ---> right
| acc: 0 |
+---------------+
|

bottom

接下來給出 3×3 脈動陣列的設計,因為採用輸出靜止的形式,我們可以預期矩陣乘法的結果跟脈動陣列的放置形式一樣,舉例來說,PE12 在計算完成後其 acc 之值即為 c₁₂。

                bk1                    bk2                    bk3
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
a1k ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
a2k ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
a3k ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+

現在定義兩個 3×3 矩陣與相乘結果如下,如果有讀者對於 a1k 等等的標示感到疑惑,那麼觀察下面 AB 中各個元素的在 C 中的出現規律應該能夠有所幫助。

Matrix A          Matrix B
----------- -----------
[r] [s] [t] [α] [β] [γ]
[u] [v] [w] [δ] [ϵ] [ζ]
[x] [y] [z] [η] [θ] [ι]

Matrix C
--------------------------------
[rα+sδ+tη] [rβ+sϵ+tθ] [rγ+sζ+tι]
[uα+vδ+wη] [uβ+vϵ+wθ] [uγ+vζ+wι]
[xα+yδ+zη] [xβ+yϵ+zθ] [xγ+yζ+zι]

為了保證計算結果正確,我們需要對原本矩陣的值做些排列,舉例來說,我們不能同時將 a₁₁a₂₁a₃₁ 推入陣列裡,原因很簡單,按照定義在 c₁₁、c₁₂ 和 c₁₃ 中 a₁₁、a₂₁ 和 a₃₁ 都需要跟 b₁₁ 相乘,但顯然 b₁₁ 只會同時在一個處理單元中出現,所以若同時推入 a₁₁、a₂₁ 和 a₃₁,最終所得的計算結果必然錯誤,要解決這個問題有兩個方法,一是錯開三者的推入時間,也就是待會會看到的處理方法,二是讓 b₁₁ 同時出現在 PE11、PE21 和 PE31,但如果這麼做硬體複雜度將會顯著提升,而且有悖脈動陣列的設計精神。

在開始計算前,我們應該將處理單元內的所有暫存器,也就是 topleftacc 都重設為 0,保證計算結果正確,同時把 a₁₁ 和 b₁₁ 推入 PE11 內,其餘部分推入 0,節省一個週期的時間。

cycle = 0, reset
----------------------------------------------------------------------
α 0 0
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
r ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+

接著把 a₁₂ 和 b₂₁ 推入 PE11、a₂₁ 推入 PE21、b₁₂ 推入 PE12 內,此時有一個處理單元進行有意義的運算,得到脈動陣列的使用率為 1/9。

cycle = 1, utilization = 1/9
----------------------------------------------------------------------
δ β 0
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: α | | top: 0 | | top: 0 |
s ---> | left: r | ---> | left: 0 | ---> | left: 0 |
| acc: rα | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
u ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+

按照同樣的思路,把 a₁₃ 和 b₃₁ 推入 PE11、a₂₂ 推入 PE21、b₂₂ 推入 PE12、a₃₁ 推入 PE31、b₁₃ 推入 PE13 內,此時有三個計算單元進行有意義的運算,得到脈動陣列的使用率為 3/9。

cycle = 2, utilization = 3/9
----------------------------------------------------------------------
η ϵ γ
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: δ | | top: β | | top: 0 |
t ---> | left: s | ---> | left: r | ---> | left: 0 |
| acc: rα+sδ | | acc: rβ | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: α | | top: 0 | | top: 0 |
v ---> | left: u | ---> | left: 0 | ---> | left: 0 |
| acc: uα | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
x ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: 0 | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+

現在因為 A 的第一列和 B 的第一行都已經送入陣列中,因此需要推入 0,其餘邏輯與前面無異,讀者可以自行推論,此時有六個計算單元進行有意義的運算,得到脈動陣列的使用率為 6/9。

cycle = 3, utilization = 6/9
----------------------------------------------------------------------
0 θ ζ
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: η | | top: ϵ | | top: γ |
0 ---> | left: t | ---> | left: s | ---> | left: r |
| acc: rα+sδ+tη | | acc: rβ+sϵ | | acc: rγ |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: δ | | top: β | | top: 0 |
w ---> | left: v | ---> | left: u | ---> | left: 0 |
| acc: uα+vδ | | acc: uβ | | acc: 0 |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: α | | top: 0 | | top: 0 |
y ---> | left: x | ---> | left: 0 | ---> | left: 0 |
| acc: xα | | acc: 0 | | acc: 0 |
+---------------+ +---------------+ +---------------+

接下來的過程一併列出,讀者可以跟著追蹤一下數值的流動狀況。

cycle = 4, utilization = 7/9
----------------------------------------------------------------------
0 0 ι
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: θ | | top: ζ |
0 ---> | left: 0 | ---> | left: t | ---> | left: s |
| acc: rα+sδ+tη | | acc: rβ+sϵ+tθ | | acc: rγ+sζ |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: η | | top: ϵ | | top: γ |
0 ---> | left: w | ---> | left: v | ---> | left: u |
| acc: uα+vδ+wη | | acc: uβ+vϵ | | acc: uγ |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: δ | | top: β | | top: 0 |
z ---> | left: y | ---> | left: x | ---> | left: 0 |
| acc: xα+yδ | | acc: xβ | | acc: 0 |
+---------------+ +---------------+ +---------------+

cycle = 5, utilization = 6/9
----------------------------------------------------------------------
0 0 0
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: ι |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: t |
| acc: rα+sδ+tη | | acc: rβ+sϵ+tθ | | acc: rγ+sζ+tι |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: θ | | top: ζ |
0 ---> | left: 0 | ---> | left: w | ---> | left: v |
| acc: uα+vδ+wη | | acc: uβ+vϵ+wθ | | acc: uγ+vζ |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: η | | top: ϵ | | top: γ |
0 ---> | left: z | ---> | left: y | ---> | left: x |
| acc: xα+yδ+zη | | acc: xβ+yϵ | | acc: xγ |
+---------------+ +---------------+ +---------------+

cycle = 6, utilization = 3/9
----------------------------------------------------------------------
0 0 0
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: rα+sδ+tη | | acc: rβ+sϵ+tθ | | acc: rγ+sζ+tι |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: ι |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: w |
| acc: uα+vδ+wη | | acc: uβ+vϵ+wθ | | acc: uγ+vζ+wι |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: θ | | top: ζ |
0 ---> | left: 0 | ---> | left: z | ---> | left: y |
| acc: xα+yδ+zη | | acc: xβ+yϵ+zθ | | acc: xγ+yζ |
+---------------+ +---------------+ +---------------+

cycle = 7, utilization = 1/9
----------------------------------------------------------------------
0 0 0
| | |
PE11 ⌄ PE12 ⌄ PE13 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: rα+sδ+tη | | acc: rβ+sϵ+tθ | | acc: rγ+sζ+tι |
+---------------+ +---------------+ +---------------+
| | |
PE21 ⌄ PE22 ⌄ PE23 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: 0 |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: 0 |
| acc: uα+vδ+wη | | acc: uβ+vϵ+wθ | | acc: uγ+vζ+wι |
+---------------+ +---------------+ +---------------+
| | |
PE31 ⌄ PE32 ⌄ PE33 ⌄
+---------------+ +---------------+ +---------------+
| top: 0 | | top: 0 | | top: ι |
0 ---> | left: 0 | ---> | left: 0 | ---> | left: z |
| acc: xα+yδ+zη | | acc: xβ+yϵ+zθ | | acc: xγ+yζ+zι |
+---------------+ +---------------+ +---------------+

我們可以看到脈動陣列只需花八個週期即可得到 m3n3k3 矩陣乘法的結果,並且跟手動算出的結果一致。

Matrix A          Matrix B
----------- -----------
[r] [s] [t] [α] [β] [γ]
[u] [v] [w] [δ] [ϵ] [ζ]
[x] [y] [z] [η] [θ] [ι]

Matrix C
--------------------------------
[rα+sδ+tη] [rβ+sϵ+tθ] [rγ+sζ+tι]
[uα+vδ+wη] [uβ+vϵ+wθ] [uγ+vζ+wι]
[xα+yδ+zη] [xβ+yϵ+zθ] [xγ+yζ+zι]

為了方便讀者理解,這裡給出脈動陣列輸入隨週期變化的過程,可以跟上面兩個矩陣相互對照。

Input                                   | bk1 bk2 bk3 |
+---------------------------------+-------------+
| Cycle 7 | 0 0 0 |
| 6 | 0 0 0 |
| 5 | 0 0 0 |
| 4 | 0 0 ι |
| 3 | 0 θ ζ |
| 2 | η ϵ γ |
| 1 | δ β 0 |
| 7 6 5 4 3 2 1 0 | α 0 0 |
------+---------------------------------+-------------+
a1k | 0 0 0 0 0 t s r | 3×3 |
a2k | 0 0 0 0 w v u 0 | Systolic |
a3k | 0 0 0 z y x 0 0 | Array |
------+---------------------------------+-------------+

從前面的過程中可以看到 3×3 的脈動陣列在計算 m3n3k3 的矩陣乘法時,其平均使用率為 ((1 + 3 + 6 + 7 + 6 + 3 + 1) ÷ 9) ÷ 8 = 0.375,也就是每個週期平均只有 3.375 個處理單元在進行有意義的計算。

      Cycle      |  7   6   5   4   3   2   1   0
-----------------+---------------------------------
# Used PEs | 1 3 6 7 6 3 1 0
-----------------+---------------------------------
# Pushed Inputs | 0 0 0 2 4 6 4 2
-----------------+---------------------------------
Stage | <--- 4 ---> <- 3 -> <2> <- 1 ->

接下來我們來算一下週期數,在上面 m3n3k3 矩陣乘法裡根據輸入情況可以把計算拆成四個階段,第一階段 <- 1 -> 為剛開始將輸入資料放進去的情況,此階段需要的週期數為 3 - 1 = 2;第二階段 <2> 為所有方向都被推入有意義的值的狀況,此階段的週期數為 1;第三階段 <- 3 -> 跟第一階段類似,需要的週期數為 3 - 1 = 2;第四階段 <--- 4 ---> 為等待計算完成的情況,需要的週期數為 3,準確認知到脈動陣列的週期數對於正確實作脈動陣列來說相當重要,若將 3×3 泛化到 s×s 的脈動陣列運行 msnsks 的矩陣乘法,我們可以期待總共需要的週期數為 (s - 1) + 1 + (s - 1) + s = 3s - 1。

前面的矩陣乘法剛好跟脈動陣列的大小相同,那麼如果要相乘的兩個矩陣比 3×3 脈動陣列本身還小,那麼脈動陣列有辦法計算答案嗎?答案是可以,只需把兩個矩陣用 0 填充,這樣原本的矩陣乘法就能被轉換為 m3n3k3 的矩陣乘法,最後計算完根據應有的輸出大小取出對應位置的值即可。

不過這裡有個需要注意的地方,那就是這個方法會浪費額外的週期。怎麼說呢?如果筆者沒算錯的話,令 s ≥ max(α, β, γ),則使用輸出靜止策略的 s×s 脈動陣列在計算 mαnβkγ 矩陣乘法所需要的週期數為 α + β + γ - 1,若我們將其擴展變成 msnsks 的矩陣乘法,將會浪費 3s - α - β - γ 個週期。

s 比 max(α, β, γ) 大很多時,浪費的週期數累積起來不容小覷,用筆者實際實作的 16×16 脈動陣列來說,若要計算的是 m1n1k1 的矩陣乘法,帶入上式知道共會浪費 45 個週期,而實際計算只有兩個週期,如果不打算加上額外的判斷電路來得知所需的計算結果在何時已經被算出的話,在加大脈動陣列時,需要考慮實際運算遇到這種情況的頻率。

這裡討論了矩陣乘法不比脈動陣列大的情況,那麼當矩陣比脈動陣列還大時該怎麼辦呢?直覺來講,我們需要在軟體層面運用前面提過的分塊矩陣乘法,來把矩陣切分成能放進脈動陣列運算的大小,不過還記得前面提到的使用率問題嗎?用 3×3 的情況來看,在矩陣不大於脈動陣列時,即使可以透過某些方法來提前結束運算,我們依然無法取得比 m3n3k3 矩陣乘法還好的平均使用率(為什麼呢?),所以唯一的機會只剩下當矩陣比脈動陣列大的情況。

綜合這兩點考量,我們需要兩層的分塊矩陣乘法,第一次分塊由軟體完成,而第二次分塊由硬體完成,為了達到硬體分塊,接下來會引入儲存矩陣的暫存區,以便於我們在硬體上實現以脈動陣列為切分大小的的分塊矩陣乘法,最後在這基礎上嘗試能否增加脈動陣列使用率!

這裡給出處理單元對應的程式碼,不過處理單元之間的線具體要怎麼接就交給讀者研究了,至於 leftright 為何用了九個位元,我們之後會揭曉答案。

// Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
// Implementation of the Processing Element Targeting TFLite
// int8 Quantization

module ProcessingElement(
input clk,
input rst,
input [8:0] left,
input [7:0] top,
output reg [8:0] right,
output reg [7:0] down,
output reg [31:0] acc
);

always @(posedge clk) begin
if (rst) begin
right <= 0; down <= 0; acc <= 0;
end else begin
acc <= $signed(acc) + $signed(left)*$signed(top);
right <= left;
down <= top;
end
end

4.3 暫存區設計

暫存區的大小受限於 FPGA 晶片內的儲存單元數量,因為有三個矩陣 ABC 需要儲存,其中 AB 的元素為 int8,C 的元素為 int32(這是為了滿足 TFLite int8 量化的需求),經過嘗試,在筆者使用的 Arty A7–100T 上最大能支援 m256n256k256 的矩陣乘法,也就是說三個暫存區的總大小為 (1 + 1 + 4) × 256 × 256 = 384 KiB,讀者若要在 FPGA 上實作,可能需要自行縮小矩陣乘法加速單元能夠支援的矩陣大小。

為了方便舉例,此處假設我們的矩陣加速單元支援 m4n4k4 的矩陣乘法,且使用的脈動陣列大小為 2×2,如果讀者尚未完全理解前面的脈動陣列設計,建議先讀懂了再回到這裡繼續閱讀。

下面是我們接下來討論會用到的三個矩陣,CAB 的乘積,因為這裡關心的是資料的讀取和儲存,所以 C 的元素以單一符號表示而不是 AB 的元素組合。

Matrix A             Matrix B             Matrix C
--------------- --------------- ---------------
[a] [b] [c] [d] [α] [β] [γ] [δ] [a] [b] [c] [d]
[e] [f] [g] [h] [ϵ] [ζ] [η] [θ] [e] [f] [g] [h]
[i] [j] [k] [l] [ι] [κ] [λ] [μ] [i] [j] [k] [l]
[m] [n] [o] [p] [ν] [ξ] [ο] [π] [m] [n] [o] [p]

現在我們需要決定暫存區的資料布局,也就是資料會如何放置在暫存區裡,這一點相當重要,不良的設計將會導致資料的讀取和寫回邏輯變複雜,更甚者可能要花更多週期,使得矩陣乘法需要花更多時間才能完成。

最簡單的方法就是把矩陣直接塞進暫存區裡,假設我們想要得到 b₁₁,那麼只需存取 BufferB[0][31:24] 即可,不過這會有個問題,考慮到一個位址對應到四個元素,我們如果一口氣取出四個元素,因為脈動陣列為 2×2 的緣故,實際上只會使用兩個元素,浪費一半的讀取頻寬;如果分兩次取出元素,則會多浪費一個週期;如果一次只取兩個需要的元素,在硬體上又會引入額外的電路來決定取出的是 BufferB[i][31:16] 還是 BufferB[i][15:0]。另外如果當今天矩陣的大小為 6×6,那麼又該怎麼塞入一個位址只能存放四個元素的暫存區呢?如果讓一個位址能儲存六個元素,那麼前面三種情況又該怎麼解決呢?

 BufferA                   BufferB                   BufferC
+------------------+ +------------------+ +------------------+
| 0: a b c d | | 0: α β γ δ | | 0: a b c d |
| 1: e f g h | | 1: ϵ ζ η θ | | 1: e f g h |
| 2: i j k l | | 2: ι κ λ μ | | 2: i j k l |
| 3: m n o p | | 3: ν ξ ο π | | 3: m n o p |
+------------------+ +------------------+ +------------------+

既然如此,不妨把一個位址對應到的資料長度從原本的矩陣大小改為脈動陣列的大小,也就是說如果脈動陣列大小為 n×n,那麼就讓暫存區一個位址能儲存的元素數量為 n,至於要怎麼把原本的矩陣儲存在暫存區呢?我們可以規定優先把矩陣每一列(讀者想用行也可以)的前 n 個元素按順序放到暫存區裡,接著再把每一列的第 n 到 2n — 1 個元素按順序放到暫存區裡,依此類推直到整個矩陣都被放進暫存區。當然為了別讓電路太過複雜,軟體應該要負責將資料處理成暫存區要的樣子才傳入硬體,而非將原矩陣傳入硬體後,再交由硬體電路把原矩陣處理後儲存到正確的位址裡。

 BufferA             BufferB             BufferC
+------------+ +------------+ +------------+
| 0: a b | | 0: α β | | 0: a b |
| 1: e f | | 1: ϵ ζ | | 1: e f |
| 2: i j | | 2: ι κ | | 2: i j |
| 3: m n | | 3: ν ξ | | 3: m n |
| 4: c d | | 4: γ δ | | 4: c d |
| 5: g h | | 5: η θ | | 5: g h |
| 6: k l | | 6: λ μ | | 6: k l |
| 7: o p | | 7: ο π | | 7: o p |
+------------+ +------------+ +------------+

不過這就完了嗎?我們先看看 m2n2k2 矩陣乘法在 2×2 脈動陣列上輸入隨週期變化的圖,有些讀者可能已經發現從 BufferA 和 BufferB 取出資料後重新排列的邏輯不一樣,說得具體一點,BufferA[0] 全都會從 a1k 進入脈動陣列裡,但 BufferB[0] 卻會從 bk1 和 bk2 分別進入脈動陣列,如果要分別實現兩種邏輯的話,除錯難度可能會很高,所以我們可以試著重新排列矩陣讓 BufferA 和 BufferB 取出資料後的排列邏輯相同,使得實作更容易一些。

Input                       | bk1 bk2 |
+---------------------+---------+
| Cycle 4 | 0 0 |
| 3 | 0 θ |
| 2 | 0 ζ |
| 1 | ϵ β |
| 4 3 2 1 0 | α 0 |
------+---------------------+---------+
a1k | 0 0 0 b a | |
a2k | 0 0 f e 0 | 2×2 SA |
------+---------------------+---------+

問題來了,該怎麼重新排列呢?這裡採用的方法是將軟體將 A 轉置後按前面的儲存方法儲存,這樣一來我們可以看到 BufferA[0]BufferB[0] 一樣都會分別從兩個通道進入脈動陣列裡,而且計算出來的矩陣不會受到影響,倘若我們選擇將 B 轉置的話,會發現計算完後脈動陣列儲存的值需要轉置才會是預期的 C,那麼轉置 A 就成為非常合理的選擇了,當然讀者想轉置 B 或者都不轉置也可以,因為所有的選擇都有其代價,我們稍後會討論代價是什麼。

 BufferA             BufferB             BufferC
+------------+ +------------+ +------------+
| 0: a e | | 0: α β | | 0: a b |
| 1: b f | | 1: ϵ ζ | | 1: e f |
| 2: c g | | 2: ι κ | | 2: i j |
| 3: d h | | 3: ν ξ | | 3: m n |
| 4: i m | | 4: γ δ | | 4: c d |
| 5: j n | | 5: η θ | | 5: g h |
| 6: k o | | 6: λ μ | | 6: k l |
| 7: l p | | 7: ο π | | 7: o p |
+------------+ +------------+ +------------+

如果說傳入的矩陣大小比暫存區設計的 4×4 還小呢?按照之前脈動陣列的思路就好,也就是簡單地在軟體上將矩陣用 0 擴充到 4×4 後按前面的規則寫入暫存區,下面的例子供讀者參考。

Matrix D          Buffer
----------- +------------+
[r] [s] [t] | 0: r s |
[u] [v] [w] | 1: u v |
[x] [y] [z] | 2: x y |
| 3: 0 0 |
| 4: t 0 |
| 5: w 0 |
| 6: z 0 |
| 7: 0 0 |
+------------+

上面的 Buffer[3]Buffer[7] 可以不須被儲存,也就是說暫存區長下面這樣,不過屆時算位址取資料時需要注意這點,當電路上發現越界時需要自行補 0 推入脈動陣列中。

 Buffer
+------------+
| 0: r s |
| 1: u v |
| 2: x y |
| 3: t 0 |
| 4: w 0 |
| 5: z 0 |
| 6: - - |
| 7: - - |
+------------+

最後我們來談談代價,更詳細點說是傳輸 AB 的代價,假設可以用一條指令填滿一個位址的值,那麼轉置與否將會影響所需的總指令數,假設要做的是 mαnβkγ 的矩陣乘法,且脈動陣列的大小為 s×s,令 round(x) 表示不小於 x 且為 s 的倍數的所有值中的最小值。

在沒做上面的暫存區最佳化的前提下,不轉置、轉置 A 和轉置 B 三者都需要 round(α) × round(γ) ÷ s + round(γ) × round(β) ÷ s 條指令才能完成,也就是說轉置 A 是最為理想的方法,然而在做了上面的暫存區最佳化的前提下,三者所需要的指令數如下所示,顯然在不同的 αβγ 的情況下,三者所需的指令數有可能不同。

  • 不轉置:(α × round(γ) + γ × round(β)) ÷ s
  • 轉置 Aγ × (round(α) + round(β)) ÷ s
  • 轉置 B:(α + β) × round(γ) ÷ s

觀察上面的式子容易看出如果 αβγ 都為 s 的倍數,那麼三者所需的指令數相同;如果只有 α 不為 s 的倍數,那麼不轉置和轉置 B 的指令數會一樣且少於轉置 A;如果只有 β 不為 s 的倍數,那麼不轉置和轉置 A 的指令數會一樣且多於轉置 B;如果只有 γ 不為 s 的倍數,那麼轉置 A 的指令數最少,轉置 B 的指令數最多,至於其餘情況需要看具體的數值才能確定,讀者可以按照實際情況代入不同的值看看結果。

從以上分析可以看到轉置 B 貌似更好,但從實作角度來說轉置 A 應該是比較不易出錯的方法,讀者可以在實作出轉置 A 後再改成轉置 B 試試看,如果表現真的較好的話,讀者可以在底下留言告知。

最後給出暫存區的程式碼,因為 AB 的大小為 4×4,所以只需四個位元作為位址值,暫存區每一列會儲存兩個 int8,所以會需要 16 個位元來儲存,另外因為 C 的元素為 int32,故 C 暫存區的每一列需要 64 個位元,具體情況請按照讀者需求修改。為了保證上一週期送出的資料存取要求能在下一個週期前完成,此處採用負緣觸發,不過如果半個週期真的做不完的話,改為正緣觸發也可以,但後續給出的程式碼讀者需自行修改相應邏輯以保證行為正確。

// Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
// Implementation of the Buffer for a 4×4 int8 Matrix

module Buffer(
input clk,
input wr_en,
input [3:0] addr,
input [15:0] data_in,
output reg [15:0] data_out
);

reg [15:0] gbuff [15:0];

always @(negedge clk) begin
if (wr_en) begin
gbuff[addr] <= data_in;
end else begin
data_out <= gbuff[addr];
end
end

endmodule

4.4 脈動陣列改進

前一小節定義了暫存區的格式,是時候來演示如何在分塊矩陣乘法的基礎上提高脈動陣列的使用率了。

讓我們關注前一小節中 C 的第一個分塊會用到 AB 中的那些元素,因為是 m4n4k4 的矩陣乘法,對於 C 的第一個分塊來說,會需要兩次 m2n2k2 的矩陣乘法才能得到答案。

Matrix A (Partial)      Matrix B (Partial)      Matrix C (Parital)
------------------ ------------------ ------------------
[a] [b] [c] [d] [α] [β] [a] [b]
[e] [f] [g] [h] [ϵ] [ζ] [e] [f]
[ι] [κ]
[ν] [ξ]

按照前面脈動陣列的做法,兩次 m2n2k2 矩陣乘法輸入隨週期變化的過程如下,為了簡單起見,這裡省略存取暫存區所需要的時間,可以看到第 0 到第 4 個週期在計算第一個分塊矩陣乘法,而第 5 到第 9 個週期在計算第二個分塊矩陣乘法。

Input                                           | bk1 bk2 |
+-----------------------------------------+---------+
| Cycle 9 | 0 0 |
| 8 | 0 0 |
| 7 | 0 ξ |
| 6 | ν κ |
| 5 | ι 0 |
| 4 | 0 0 |
| 3 | 0 θ |
| 2 | 0 ζ |
| 1 | ϵ β |
| 9 8 7 6 5 4 3 2 1 0 | α 0 |
------+-----------------------------------------+---------+
a1k | 0 0 0 d c 0 0 0 b a | |
a2k | 0 0 h g 0 0 0 f e 0 | 2×2 SA |
------+-----------------------------------------+---------+
<- 4 -> <3> <2> <1> <- 4 -> <3> <2> <1>

原先做法在第 0 和 5 個週期會將全部處理單元的 lefttopacc 都重置為 0,不過既然兩次分塊矩陣乘法的目的分塊都是同一個,表示第二次乘法不須重置對吧?可以直接在上一次的基礎上累加就好。有了這個觀察,不妨將兩次矩陣乘法縮減為一次,也就是說將兩次的輸入直接拼接起來而不補 0。

Input                               | bk1 bk2 |
+-----------------------------+---------+
| Cycle 6 | 0 0 |
| 5 | 0 0 |
| 4 | 0 ξ |
| 3 | ν κ |
| 2 | ι ζ |
| 1 | ϵ β |
| 6 5 4 3 2 1 0 | α 0 |
------+-----------------------------+---------+
a1k | 0 0 0 d c b a | |
a2k | 0 0 h g f e 0 | 2×2 SA |
------+-----------------------------+---------+
<- 4 -> <3> <--- 2 ---> <1>

這樣的方法正確嗎?我們可以觀察推入資料在脈動陣列裡的停留時間來得知這點。第一次分塊矩陣乘法被推入的第一個元素是 a,會在第 1 個週期進入脈動陣列裡(第 0 個週期是準備進入),在第 3 個週期離開脈動陣列,而第二次分塊矩陣乘法被推入的第一個元素是 c ,會在第 3 個週期進入脈動陣列裡,依此類推,可以知道不同次分塊矩陣乘法的元素不會在脈動陣列裡相遇,因此結果正確。

這樣的做法其實等價於將兩次 m2n2k2 的矩陣乘法合併為一次 m2n2k4 的矩陣乘法,在上面的例子裡一共省下了三個週期,透過合併,我們將第一次矩陣乘法的 <3> 跟第二次矩陣乘法的 <1> 濃縮成一次 <2>,並且節省了第一次矩陣乘法中用來等待計算結果的 <4>,一共三個週期,讀者應該不難看出當暫存區能容納的矩陣越大時,相對於原本一次次分開計算的方法,將目標分塊相同的矩陣乘法合併起來能夠節省越多時間,也就是說分塊越多節省越多。

按此算法可知假設一次暫存區存取的大小為 s 個元素且脈動陣列大小為 s×s,令 c 為正整數,那麼對於 msnskcs 的矩陣乘法來說,原始方法需要 c(3s — 1) = (3s — 1)c 個週期完成,而這裡的改進只需要 (s — 1) + ((c — 1)s + 1) + (s — 1) + s = sc + 2s — 1 個週期,也就是說可以節省 (2s — 1)c — 2s + 1 個週期且節省的週期數正比於 c

如果讀者沒看出來節省的週期數有多誇張的話,我們可以看看改進後的加速比如何,如果 c 遠大於 s 時,加速比近似於 3,也就是平均來說原方法做一次矩陣乘法的時間夠改進方法做三次矩陣乘法;如果用前面的 s = 2 來說,加速比為 5c / (2c + 3),當 c = 6 時就能達到兩倍的加速,且當 c 越大時加速比會越接近 2.5。

總而言之,透過上面的合併策略我們成功用更少的週期數完成所需的計算,也就是說我們成功提高了脈動陣列的使用率!

4.5 控制邏輯

現在我們所需的資料都已經存在暫存區內,最後要做的將資料從中取出並餵給脈動陣列運算並將結果寫回,為了讓脈動陣列能夠計算比其更大的矩陣乘法,必然需要在硬體上實作分塊矩陣乘法,所以控制邏輯在概念上可以分成兩個部分,第一個是分塊控制邏輯,第二個是矩陣乘法邏輯,對於矩陣乘法邏輯來說,脈動陣列本身的計算已經在前面詳細解說過,因此這裡會著重於資料的讀取和傳輸。本小節會以Python 或 C/C++ 程式碼描述運作邏輯,方便讀者理解,最後才會以 Verilog 呈現相應的狀態機。

因為暫存區最大能容納 4×4 的矩陣,需要把最大為 m4n4k4 的矩陣乘法切分為數個 m2n2k2 的矩陣乘法,套用前面章節出現過的分塊矩陣乘法概念的話,我們需要實現的電路功能大體如下,別忘了 A 已被轉置以及處理邊界大小沒辦法被 2 整除的情況。

# Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
# 2×2-Tiled Matrix Multiplication

assert(max(M, N, k) <= 4)

for m in range(0, M+1, 2):
for n in range(0, N+1, 2):
for k in range(0, K+1, 2):
MM = min(m+2, M) - m
NN = min(n+2, N) - n
KK = min(k+2, K) - k
C[m:m+MM][n:n+NN] += A[k:k+KK][m:m+MM] * B[k:k+KK][n:n+NN]

上面將維度加上 1 的目的是為了達到 round₂(⋅) 的效果,讀者可以試著代入不同的值看看結果。這裡的邏輯比先前在討論分塊矩陣乘法還複雜的原因為我們假設矩陣大小能被分塊大小整除,但因為實際情況不會這麼理想,需要額外處理來確定實際的分塊大小為何。

接下來我們要把 C[m:m+MM][n:n+NN] += A[k:k+KK][m:m+MM] * B[k:k+KK][n:n+NN] 轉換成類似脈動陣列運作的情況,並且引入前面設計的暫存區。這裡會將資料讀取、計算、資料寫回切開成三個部分,重要的程式碼會有註解編號,讀者可以按照編號找到對應的說明。

# Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
# 2×2-Tiled Matrix Multiplication with Buffers

assert(max(M, N, k) <= 4)

cnt_M = (M + 1) // 2
cnt_N = (N + 1) // 2

for m in range(cnt_M):
for n in range(cnt_N):
CC = [[0, 0], [0, 0]]

for k in range(0, K, 2):
AA = [[0, 0], [0, 0]] # 1.
BB = [[0, 0], [0, 0]]

for read_i in range(2):
A_index = m*K + k + read_i # 2.
B_index = n*K + k + read_i
AA[read_i] = A_buffer[A_index] if k + read_i < K # 3.
BB[read_i] = B_buffer[B_index] if k + read_i < K

systolic_array_2x2(AA, BB, CC) # 4.

for write_i in range(2):
C_index = n*M + m*2 + write_i # 5.

if C_index < M + n*M: # 6.
C_buffer[C_index] = CC[write_i]
  1. 在電路設計上會是暫存器,把還沒推進脈動陣列的值暫存起來,至於 CC 在電路上會連接到脈動陣列,例如 PE11 的 acc 實際對應到的是 CC[0][0]
  2. 注意到兩個輸入矩陣的第一個維度皆為 k,按照暫存區的儲存方法,我們知道第一列的前兩個值會放在位址 0,第二列的前兩個值會放在位址 1,最後一列的前兩個值會放在位址 k - 1,第一列的第三和第四個值會放在位址 k,依此類推可知第 m 列的第 2k 和第 2k + 1 個值會放在位址 m×K + k,這裡的 k 對應到程式碼裡的 k + read_i,如果有疑惑的話(不可能沒有吧?),下面有個例子可供讀者研究。
  3. 避免越界讀取,若當前位址超出邊界,用 [0, 0] 代替,別忘了補 0 擴充成 m2n2k2 的矩陣乘法並不會得到錯誤的結果。
  4. 脈動陣列運算,實際電路會把取資料、計算和存資料重疊執行以節省週期。
  5. 矩陣 C 的格式也要按照前面的暫存區格式儲存,注意到 C 的大小為 M×N,按照第 2 點提到的計算方法可以推得第 n 列的第 2m 和第 2m + 1 個值會放在位址 n×M + m,這裡的 m 對應到程式碼裡的 m*2 + write_im*2m 從原本的分塊次數還原為具體的索引值,而上面的 k 本身就是索引值,故不需還原。
  6. 避免越界寫入,如果沒越界的話我們需要把對應位置的資料取出累加後再寫回暫存區。

這裡附上 4×4 脈動陣列使用的暫存區例子,不過概念上跟 2×2 的如出一轍,把前面程式碼的 2 改成 4 即可,暫存區旁邊的數字為位址,讀者可以試著手動將下面 m7n9k5 的矩陣乘法分塊成 m4n4k4 的矩陣乘法後,觀察位址跟當前 mnk 之間的關係。

M=7, K=5, N=9
-------------

A.T buffer(A.T)
[5 3 9 1 7 1 8] [5 3 9 1] 0
[4 6 4 2 5 3 3] [4 6 4 2] 1
[3 0 7 1 5 4 7] [3 0 7 1] 2
[2 2 8 9 3 2 9] [2 2 8 9] 3
[6 1 9 8 3 6 1] [6 1 9 8] 4
[7 1 8 0] 5
[5 3 3 0] 6
[5 4 7 0] 7
[3 2 9 0] 8
[3 6 1 0] 9

B buffer(B)
[1 1 2 4 2 1 5 8 1] [1 1 2 4] 0
[1 2 3 1 3 1 5 7 3] [1 2 3 1] 1
[6 4 6 9 8 3 1 5 3] [6 4 6 9] 2
[3 2 4 3 2 1 0 5 4] [3 2 4 3] 3
[4 6 3 1 3 2 1 4 6] [4 6 3 1] 4
[2 1 5 8] 5
[3 1 5 7] 6
[8 3 1 5] 7
[2 1 0 5] 8
[3 2 1 4] 9
[1 0 0 0] 10
[3 0 0 0] 11
[3 0 0 0] 12
[4 0 0 0] 13
[6 0 0 0] 14

C buffer(C)
[ 57 65 66 63 68 32 54 117 70] [ 57 65 66 63] 0
[ 19 25 35 25 31 13 46 80 35] [ 19 25 35 25] 1
[115 115 131 136 129 60 81 211 128] [115 115 131 136] 2
[ 68 75 74 50 58 31 24 104 94] [ 68 75 74 50] 3
[ 63 61 80 90 84 36 68 143 67] [ 63 61 80 90] 4
[ 58 63 61 55 65 30 30 83 66] [ 58 63 61 55] 5
[ 84 66 106 126 102 43 63 169 80] [ 84 66 106 126] 6
[ 68 32 54 117] 7
[ 31 13 46 80] 8
[129 60 81 211] 9
[ 58 31 24 104] 10
[ 84 36 68 143] 11
[ 65 30 30 83] 12
[102 43 63 169] 13
[ 70 0 0 0] 14
[ 35 0 0 0] 15
[128 0 0 0] 16
[ 94 0 0 0] 17
[ 67 0 0 0] 18
[ 66 0 0 0] 19
[ 80 0 0 0] 20

這裡筆者列出當 mn 分別為 1 和 2 的情況,做為讀者自行分塊的參考。

m = 1, n = 2
------------

Read, k = 0
-----------
buffer(A.T) buffer(B)
5 [7 1 8 0] [1 0 0 0] 10
6 [5 3 3 0] [3 0 0 0] 11
7 [5 4 7 0] [3 0 0 0] 12
8 [3 2 9 0] [4 0 0 0] 13

Read, k = 4
-----------
buffer(A.T) buffer(B)
9 [3 6 1 0] [6 0 0 0] 14
x [0 0 0 0] [0 0 0 0] x
x [0 0 0 0] [0 0 0 0] x
x [0 0 0 0] [0 0 0 0] x

Write, k = 0 and 4
------------------
buffer(C)
[ 67 0 0 0] 18
[ 66 0 0 0] 19
[ 80 0 0 0] 20
[ 0 0 0 0] x

眾所周知前面的程式碼無法直接變成硬體描述語言,所以需要將巢狀迴圈輾平成一個迴圈,方便之後翻譯為狀態機。

# Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
# 2×2-Tiled Matrix Multiplication with Loop Flattening

assert(max(M, N, k) <= 4)

cnt_M = (M + 1) // 2
cnt_N = (N + 1) // 2

m, n, k, read_i, write_i = 0, 0, 0, 0, 0
AA = [[0, 0], [0, 0]]
BB = [[0, 0], [0, 0]]
CC = [[0, 0], [0, 0]]

while True:
if k < K and read_i < 2:
A_index = m*K + k + read_i
B_index = n*K + k + read_i
AA[read_i] = A_buffer[A_index] if k + read_i < K else [0, 0]
BB[read_i] = B_buffer[B_index] if k + read_i < K else [0, 0]
read_i = read_i + 1

if k < K and read_i == 2:
systolic_array_2x2(AA, BB, CC)
read_i = 0
k = k + 2

if k >= K and write_i < 2:
C_index = n*M + m*2 + write_i

if C_index < M + n*M:
C_buffer[C_index] = CC[write_i]
write_i = write_i + 1

if k >= K and write_i == 2:
CC = [[0, 0], [0, 0]]
k = 0
write_i = 0
n = n + 1

if n == cnt_N:
n = 0
m = m + 1

if m == cnt_M:
break

最後需要的只是把上面的邏輯變換成 Verilog,不過這個部分就交給讀者細細品味了,這裡著重於如何將讀寫資料和脈動陣列運算重疊起來。

回憶先前 m2n2k4 的過程圖,那時我們省略了資料存取所需要的時間,而現在我們要把存取時間也考慮進來。前面提到暫存區是負緣觸發,所以上一週期發起的存取要求在下個週期到來時便保證完成,為了讓讀寫和運算重疊,這個約束條件告訴我們如果可以的話,在上一週期便需要計算好下一週期所需資料的位址並送出需求。

理論上在軟體將暫存區資料準備好後,我們需要透過某種方式告知矩陣乘法加速單元可以開始運作,假設是用一條指令告知,則這條指令應當同時給出矩陣乘法的 MNK 之值,當加速單元收到這條開始運作的指令時,我們可以先預取 A_buffer[0]B_buffer[0],這樣下一個週期(也就是第 0 個週期)加速單元開始運作時我們便能將 aα 準備好送入脈動陣列裡。

接下來我們把每個週期該發生的事情列出如下,因為 AB 兩個矩陣的邏輯一樣,簡潔起見,從第 1 個週期開始只列出 A 相關的事件。

  • 第 0 個週期
    — 預取 A_buffer[1]B_buffer[1]
    A_buffer[0]B_buffer[0] 之值分別存入 AA[0]BB[0]
    A_buffer[0][15:8]B_buffer[0][15:8] 分別推入 a1kbk1,剩下的 a2kbk2 則推入 0
    — 假設存入 AABB 不會在該週期完成
  • 第 1 個週期
    — 脈動陣列開始計算
    — 預取 A_buffer[2]
    A_buffer[1] 存入 AA[1]
    A_buffer[1][15:8]AA[0][7:0] 分別推入 a1ka2k
  • 第 2 個週期
    — 預取 A_buffer[3]
    A_buffer[2] 存入 AA[0](存入 AA[0] 不會蓋掉需要的值嗎?)
    A_buffer[2][15:8]AA[1][7:0] 分別推入 a1ka2k
  • 第 3 個週期
    A_buffer[3] 存入 AA[1]
    A_buffer[3][15:8]AA[0][7:0] 分別推入 a1ka2k
  • 第 4 個週期
    — 0 和 AA[1][7:0] 分別推入 a1ka2k
    CC[0][63:32] 計算完成
  • 第 5 個週期
    CC[0][31:0]CC[1][63:32] 計算完成,將 CC[0] 寫回 C_buffer[0]
  • 第 6 個週期
    CC[1][31:0] 計算完成,將 CC[1] 寫回 C_buffer[1]
    — 重置所有處理單元的暫存器為 0
    — 預取 A_buffer[0]
  • 第 7 個週期
    — 脈動陣列開始計算
    — 預取 A_buffer[1]
    A_buffer[0] 存入 AA[0]
    A_buffer[1][15:8]AA[0][7:0] 分別推入 a1ka2k

從第 7 個週期開始便重複第 0 個週期的邏輯,依此類推。讀者可以看到在上面的流程裡我們成功將資料讀取和計算重疊起來,消除了等待資料讀取的時間,並且在第 6 個週期時將資料寫入和脈動陣列重置重疊起來,消除了等待資料寫回的時間,不過需要注意到最後一次資料寫回需要額外等待一個週期才能結束整個運算,否則結果會出錯。

到此為止我們已經將整個矩陣乘法加速單元的重要部件講解完畢,讀者在讀完本章節的範例後應能自行設計出符合自己需求的矩陣乘法加速單元,最後筆者在此附上利用 2×2 脈動陣列的加速單元狀態機的 Verilog 片段,箇中奧妙就留待讀者細細品嘗,可以確定程式碼還有可以改進的空間,而 offset 跟 TFLite in8 量化的規則有關,其規定在輸入與卷積核相乘之前,輸入值需要先加上一個偏移量。

// Copyright (c) 2023-2024 Chung-Yi Chen (Yeecy)
// Code Snippet of the Matrix Multiplication Unit Targeting TFLite
// int8 Quantization

input in_valid; // active for one cycle
output reg busy; // active while computing
reg pe_rst // reset all processing elements while active

always @(posedge clk) begin
if (in_valid) begin
busy <= 1;
pe_rst <= 1;
AA[0] <= 0; AA[1] <= 0;
BB[0] <= 0; BB[1] <= 0;
A_addr <= 0; B_addr <= 0;
m <= 0; k <= 0; n <= 0;
state <= 4'b0001;
end else if (busy) begin
case(state[3:2])
2'b00: begin
pe_rst <= 0;
C_wr_en <= 0;
AA[state[1]] <= k < K ? A_data_out : 0;
BB[state[1]] <= k < K ? B_data_out : 0;
a1k <= k < K ? $signed(A_data_out[15:8]) + $signed(offset)
: 0;
a2k <= $signed(AA[state[0]][7:0]) + $signed(offset);
bk1 <= k < K ? B_data_out[15:8] : 0;
bk2 <= BB[state[0]][7:0];
A_addr <= m*K + k + 1;
B_addr <= n*K + k + 1;
k <= k + 1;

if (k + 1 < K + 2) begin
state <= {2'b00, state[0], state[1]};
end else begin
state <= 4'b0101
end
end
2'b01: begin
C_wr_en <= 1;
C_addr <= n*M + m*2 + state[1];
C_data_in <= CC[state[1]];

if (m*2 + state[0] >= M || state[1] == 1'b1) begin
pe_rst <= 1;
AA[0] <= 0; AA[1] <= 0;
BB[0] <= 0; BB[1] <= 0;
A_addr <= n + 1 < cnt_N ? m * K : (m + 1) * K;
B_addr <= n + 1 < cnt_N ? (n + 1) * K : 0;
k <= 0;
n <= n + 1 < cnt_N ? n + 1 : 0;
m <= n + 1 < cnt_N ? m : m + 1;
state <= n + 1 == cnt_N && m + 1 == cnt_M ? 4'b1100
: 4'b0001;
end else begin
state <= {2'b01, state[0], state[1]};
end
end
2'b11: begin
C_wr_en <= 0;
busy <= 0;
end
endcase
end
end

--

--

Yeecy

A Ph.D. student at NYCU CS and a compiler engineer at ICEshell Co., Ltd. You can find more information about me on my GitHub page github.com/ADNRs.