摘要

针对脉动阵列 systolic 做了一些小改进,利用缓冲器实现矩阵乘法期的复用。

模块设计

原理

以下是systolic的矩阵计算原理,主要是控制每个单元计算所需的每行的数据流在同一时刻流过,利用加法器和流水乘法器实现矩阵乘法。此外根据数据流入的两个方向,可以计算出systolic阵列的数据流实际的速度为红色箭头所示。可以计算出完成时间的关系
图片1

扩充

在此基础上做了一点小扩充:

图片2
在此基础上思考输入与输出的格式对齐方案:我使用了一种简单的设计,在计算完成时刻对整体的数据向下或右移出,这样不用添加额外的寄存器或读取转换电路,只需要将向右侧流动的数据位宽改为结果位宽即可,这样可以实现阵列的复用。

可根据需求另外设计:
1.根据左右乘矩阵,修改generate中的数据线横向纵向的连接方式即可

2.根据外部数据输出格式:修改输入输出的矩阵行列元素顺序:顺序输入–顺序输出、逆序输入–逆序输出
图片3
图片4
图片5

代码编写

代码如下:

最小单元设计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
module pulse_arrays_pe  #(
parameter WIDTH_left = 8,
parameter WIDTH_up = 8,
parameter WIDTH_out = 8
)(
input wire clk,
input wire rst,
input wire [1:0]mode, // mode=0 tpu_computer mode=1,2 shift out_data

input wire [WIDTH_out-1:0] left,
input wire [WIDTH_up-1:0] up,

output reg [WIDTH_out-1:0] right,
output reg [WIDTH_up-1:0] down,
output reg [WIDTH_out-1:0] out_data
);
//reg [WIDTH_out-1:0] out_data;
wire [WIDTH_out-1:0] temp_data;

always@(posedge clk)begin
if(!rst)begin
right <= 0;
down <= 0;
out_data <= 0;
end
else begin
//computer
if(mode==0)begin
right<= left;//{{(WIDTH_out-WIDTH_left){1'd0}},left}
down <= up;
out_data <= temp_data+out_data;
end
//shift out_data
else if(mode==2)begin
right <= out_data;
out_data <= 0;
end

else
right <= left;

end
end

//multiply module instantiation
FIX_unsigned_MUL #(.WIDTH_multiplicand(WIDTH_left), .WIDTH_multiplier(WIDTH_up))
uut
(
.clk (clk),
.rst (rst),
.valid (1'd1),
.multiplicand (left[WIDTH_left-1:0]),
.multiplier (up),
.ready (),
.product (temp_data)
);

endmodule

通用阵列:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244

module pulse_arrays #(
parameter WIDTH_left = 8,
parameter WIDTH_up = 8,
parameter WIDTH_out = 8,
parameter Mritx_M = 3,//output row
parameter Mritx_N = 3,//
parameter Mritx_L = 3,//output col
parameter Mritx_LOG2_size = 10 //counter's width
)(
input clk,
input rst,

input wire valid_left,
input wire valid_up,
input wire [Mritx_M*WIDTH_left-1:0] left,
input wire [Mritx_L*WIDTH_up-1:0] up,

output reg ready, //ready for input
output wire [WIDTH_out*Mritx_M-1:0] product
);
localparam idl = 4'd0;
localparam state_in = 4'd1;
localparam state_out= 4'd2;
// state register
reg [3:0] state;


wire [WIDTH_out-1:0] left_temp [Mritx_M*Mritx_L-1:0]; //x_unite_wire
wire [WIDTH_up-1:0] up_temp [Mritx_M*Mritx_L-1:0]; //y_unite_wire

//enable signal

// reg [Mritx_M-1:0] left_shift_en;
// reg [Mritx_L-1:0] up_shift_en;
reg star=0,export=0,finish=0;//flag
reg [Mritx_LOG2_size-1:0] cnt_flow1;
reg [Mritx_LOG2_size-1:0] cnt_flow2;

reg [2*Mritx_L-1:0] mode_control=0;// 行同步控制 初始赋值0 大量变1易串扰不稳定
//output
wire [WIDTH_out-1:0] out_data[Mritx_M*Mritx_L-1:0];


assign left_temp[0] = valid_left?{{(WIDTH_out-WIDTH_left){1'd0}},left[WIDTH_left-1:0]}:0;
assign up_temp[0] = valid_up?up[WIDTH_up-1:0]:0;

always @(posedge clk ) begin
if(!rst)begin
state <= idl;
star <= 0;
export<= 0;
ready <= 0;
finish<= 0;
cnt_flow1 <= 0;
cnt_flow2 <= 0;
mode_control <= 0;
end
else begin
case (state)
idl :begin
ready <= 1;
star <= 0;
export <= 0;
finish <= 0;
cnt_flow1 <= 0;
cnt_flow2 <= 0;
mode_control <= 0;
if(valid_left&valid_up)
state <= state_in;
else
state <= idl;
end
state_in :begin
ready <= 0;
star <= 1;
export <= 0;
finish <= 0;
cnt_flow1 <= cnt_flow1 + 1;
cnt_flow2 <= 0;
if(cnt_flow1==Mritx_M+Mritx_N+Mritx_L-2+WIDTH_up-1)begin
state <= state_out;
mode_control <= {Mritx_L{2'd2}};
end
else
state <= state_in;
end
state_out:begin
ready <= 0;
star <= 0;
cnt_flow1 <= 0;
cnt_flow2 <= cnt_flow2 + 1;
if(cnt_flow2==0)
mode_control <= {Mritx_L{2'd1}}<<2;
else
mode_control <= mode_control<<2;
if(cnt_flow2==Mritx_L) begin
state <= idl;
finish<= 1;
export<= 0;
end
else begin
state <= state_out;
finish<= 0;
export<= 1;
end
end

default: begin
ready <= 0;
star <= 0;
export <= 0;
finish <= 0;
cnt_flow1 <= 0;
cnt_flow2 <= 0;
mode_control <= 0;
state <= idl;
end
endcase
end
end

generate
genvar i,j;
for(i=0;i<=Mritx_M-1;i=i+1)begin
for(j=0;j<=Mritx_L-1;j=j+1)begin:pulse_arrays_pex
if(i==0&&j==0)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[0]),
.up (up_temp[0]),
.right (left_temp[i*Mritx_L+j+1]),
.down (up_temp[(i+1)*Mritx_L+j]),
.out_data (out_data[i*Mritx_L+j])
);
end

else if(i==Mritx_M-1&&j==Mritx_L-1)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (product[WIDTH_out*(i+1)-1:WIDTH_out*i]),
.down (),
.out_data (out_data[i*Mritx_L+j])
);
end
else if(i==Mritx_M-1)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (left_temp[i*Mritx_L+j+1]),
.down (),
.out_data (out_data[i*Mritx_L+j])
);
end
else if(j==Mritx_L-1)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (product[WIDTH_out*(i+1)-1:WIDTH_out*i]),
.down (up_temp[(i+1)*Mritx_L+j]),
.out_data (out_data[i*Mritx_L+j])
);
end
else begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (left_temp[i*Mritx_L+j+1]),
.down (up_temp[(i+1)*Mritx_L+j]),
.out_data (out_data[i*Mritx_L+j])
);
end
end
end
endgenerate

generate
genvar m,n;
for(m=1;m<Mritx_M;m=m+1)begin:shift_register_left
shift_register #(
.WIDTH_in(WIDTH_left),
.WIDTH_out(WIDTH_out),
.DEEP(m),
.PTR_SIZE(Mritx_LOG2_size)
)shift_register_left(
.clk(clk),
.rst(rst),
.shift_en(valid_left),
.shift_in(left[(m+1)*WIDTH_left-1:(m)*WIDTH_left]),
.shift_out(left_temp[m*Mritx_L])
);
end
for(n=1;n<Mritx_L;n=n+1)begin:shift_register_up
shift_register #(
.WIDTH_in(WIDTH_up),
.WIDTH_out(WIDTH_up),
.DEEP(n),
.PTR_SIZE(Mritx_LOG2_size)
)shift_register_up(
.clk(clk),
.rst(rst),
.shift_en(valid_up),
.shift_in(up[(n+1)*WIDTH_up-1:n*WIDTH_up]),
.shift_out(up_temp[n])
);
end
endgenerate

endmodule

总结

要想在这个架构上实现真正的流水线矩阵乘法器,需要按照第一张图中的单元计算完成梯度图来将数据读出,但我目前并没有找到在不消耗大量硬件资源的情况下可通用的方法,后续找到的话会继续更新。

More info: CSND Github