# ChatGLM3 Typical Computation Graph
## data flow
```
query -> "你好"
|
tokenizer -> input_ids [6]
|
rotary_pos_emb embedding -> [1, 6, 4096]
\ /
GLMBlock x 28 -> [6, 1, 4096] <---|
RMSNorm -> [6, 1, 4096] | final_layernorm
[-1:] -> [1, 1, 4096] |
Linear -> [1, 1, 65024] | output_layer 4096->65024
softmax -> [1, 65024] |
multinomial -> [1] |
cat([input_ids, next_tokens]) ---|
↓
tokenizer.decode( )
```

## GLMBlock

```
input
/ \
/ RMSNorm hidden_states -> [6, 1, 4096]
| | / \
| | | pow(2) -> [6, 1, 4096]
| | | |
| | | mean -> [6, 1, 1]
| | | ↓
| | | rsqrt( + eps) -> [6, 1, 1]
| | \ /
| | mul -> [6, 1, 4096]
| | \ weight -> [4096]
| | \ /
| RMSNorm mul -> [6, 1, 4096]
| \
| SelfAttention x -> [6, 1, 4096]
| | |
| | Linear -> [6, 1, 4608] 4096->4608
| | / | \
| | q k v [6, 1, 32, 128] [6, 1, 2, 128] [6, 1, 2, 128]
| | / | \
| |             pos_emb    pos_emb   \    -> cat( x0*y0-x1*y1, x1*y0+x0*y1, x, y)
| | | | |
| | | expand expand -> [6, 1, 32, 128] [6, 1, 32, 128]
| | permute permute permute -> [1, 32, 6, 128] [1, 32, 6, 128] [1, 32, 6, 128]
| | \ / |
| | |---- matmul | -> [1, 32, 6, 128] [1, 32, 128, 6] -> [1, 32, 6, 6]
| | | add(mask) / -> [1, 32, 6, 6]
| | attention| softmax / -> [1, 32, 6, 6] dim:-1
| | | \ /
| | |---- matmul -> [1, 32, 6, 6] [1, 32, 6, 128] -> [1, 32, 6, 128] -> [6, 1, 4096]
| SelfAttention Linear -> [6, 1, 4096] 4096->4096
| /
| dropout
\ /
Add
/ \
| RMSNorm hidden_states -> [6, 1, 4096]
| | / \
| | | pow(2) -> [6, 1, 4096]
| | | |
| | | mean -> [6, 1, 1]
| | | ↓
| | | rsqrt( + eps) -> [6, 1, 1]
| | \ /
| | mul -> [6, 1, 4096]
| | \ weight -> [4096]
| | \ /
| RMSNorm mul -> [6, 1, 4096]
| /
| mlp /
| | Linear -> [6, 1, 27392] 4096->27392
| | / \
| | chunk1 chunk0 -> [6, 1, 13696]
| | | | \
| | | | sigmoid
| | | | /
| | | mul
| | \ /
| | mul -> [6, 1, 13696]
| mlp Linear -> [6, 1, 4096] 13696->4096
| /
| dropout
| /
Add
```
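The decoding tail of the data-flow diagram (take the last position's logits, softmax, multinomial sampling, then cat the new token back onto `input_ids`) can be sketched in a few lines of PyTorch. This is a minimal sketch with a toy 10-token vocabulary; the real `output_layer` produces `[1, 65024]` logits.

```python
import torch

def sample_next_token(logits: torch.Tensor) -> torch.Tensor:
    # logits for the last position only, e.g. [1, vocab_size]
    probs = torch.softmax(logits, dim=-1)
    # draw one token id from the distribution -> [1, 1]
    return torch.multinomial(probs, num_samples=1)

# toy vocabulary of 10 tokens standing in for 65024
logits = torch.randn(1, 10)
next_token = sample_next_token(logits)

input_ids = torch.tensor([[3, 1, 4]])
# feed the sampled token back in for the next decoding step
input_ids = torch.cat([input_ids, next_token], dim=-1)
print(input_ids.shape)  # torch.Size([1, 4])
```

Generation stops when the sampled token is the EOS id or a length limit is hit, at which point the accumulated ids go through `tokenizer.decode`.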
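The RMSNorm branch drawn inside GLMBlock (`pow(2)` → `mean` → `rsqrt(+eps)` → scale by the input, then by the learned weight) is equally compact in code. A minimal sketch; `eps = 1e-5` is an assumption, the shapes in the comments match the diagram.

```python
import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-5) -> torch.Tensor:
    # [6, 1, 4096] -> pow(2) -> mean over the hidden dim -> [6, 1, 1]
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    # rsqrt(var + eps) broadcasts back over the hidden dim -> [6, 1, 4096]
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # learned per-channel scale: [4096]
    return weight * hidden_states

x = torch.randn(6, 1, 4096)
w = torch.ones(4096)
print(rms_norm(x, w).shape)  # torch.Size([6, 1, 4096])
```

Unlike LayerNorm there is no mean subtraction and no bias, which is why the diagram shows only the squared-mean path.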
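The `pos_emb` step on q and k is the rotary position embedding: in complex form, (x0 + i·x1)·(cos + i·sin) = (x0·cos − x1·sin) + i·(x1·cos + x0·sin). A sketch, assuming the ChatGLM-style convention of rotating interleaved pairs in the first half of the head dim and passing the second half through unchanged (the trailing terms in the diagram's `cat`); details may differ from the actual code.

```python
import torch

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """x: [seq, batch, heads, head_dim], e.g. [6, 1, 32, 128].
    cos, sin: [seq, rot_dim // 2] rotary tables, one angle per pair."""
    rot_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # treat consecutive pairs (x0, x1) as complex numbers
    x0 = x_rot[..., 0::2]
    x1 = x_rot[..., 1::2]
    cos = cos[:, None, None, :]  # broadcast over batch and heads
    sin = sin[:, None, None, :]
    out0 = x0 * cos - x1 * sin   # real part
    out1 = x1 * cos + x0 * sin   # imaginary part
    x_rot = torch.stack((out0, out1), dim=-1).flatten(-2)  # re-interleave pairs
    return torch.cat((x_rot, x_pass), dim=-1)
```

After this step k and v (2 KV heads) are expanded to the 32 query heads, permuted to `[1, 32, 6, 128]`, and fed into the q·kᵀ → mask → softmax → ·v attention chain shown above.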
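The MLP in the diagram is a SwiGLU gate: the first Linear produces both gate and value at once (27392 = 2 × 13696), one chunk goes through SiLU (x·sigmoid(x), the `sigmoid` → `mul` pair), and the result gates the other chunk before the down-projection. A dimension-agnostic sketch; the toy weight shapes stand in for the real 4096→27392 and 13696→4096 layers, and the `dense_h_to_4h`/`dense_4h_to_h` names in the comments are the usual ChatGLM layer names.

```python
import torch
import torch.nn.functional as F

def glm_mlp(x: torch.Tensor, w_in: torch.Tensor, w_out: torch.Tensor) -> torch.Tensor:
    h = F.linear(x, w_in)        # dense_h_to_4h: [s, b, h] -> [s, b, 2*ffn]
    a, b = h.chunk(2, dim=-1)    # two [s, b, ffn] halves
    h = F.silu(a) * b            # silu(a) = a * sigmoid(a), gating the other half
    return F.linear(h, w_out)    # dense_4h_to_h: [s, b, ffn] -> [s, b, h]

# toy shapes standing in for 4096 / 27392 / 13696
x = torch.randn(6, 1, 4)
w_in = torch.randn(8, 4)     # -> hidden of 8, chunked into two 4s
w_out = torch.randn(4, 4)
print(glm_mlp(x, w_in, w_out).shape)  # torch.Size([6, 1, 4])
```

The output then passes through dropout and the second residual `Add`, closing the GLMBlock.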