educational app to showcase numerics behind transformer models

tiny-transformers

Minimal from-scratch implementation of a basic single-block transformer. Trains locally using SPSA (simultaneous perturbation stochastic approximation), with no external dependencies or API calls.

Usage

Configure an example (here the model learns the simple repeating sequence aa|bb|aa|bb|...):

	t := trainModel(ctx, []rune("aa|bb|aa|bb|aa|bb|"), 4, seed, 5000, 0.01, 0.0001)

Run in the terminal:

$ go run .
Model 0 has loss 0.1352
Model 1 has loss 0.0848
Model 2 has loss 0.1096
Model 3 has loss 0.1206
The winner is 1! (828 parameters, d_model 11, ctx 8, vocab 3, 0.0848 loss)
Seed used 1769723094689693640

Trained in 0.591 seconds
Enter context, up to 8 chars: aa|b

Full breakdown:
Input embeddings (xs)
[ 1.000  0.000  0.000  1.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 1.000  0.000  0.000  0.000  1.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  1.000  0.000  0.000  0.000  1.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  1.000  0.000  0.000  0.000  1.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
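
The embedding matrix above appears to be a one-hot token encoding (vocab 3, in the order a, |, b) concatenated with a one-hot position encoding (ctx 8), giving d_model = 3 + 8 = 11, with unused context slots left all-zero. A sketch of that construction, assuming this reading of the printout (the function name `embed` is hypothetical):

```go
package main

import "fmt"

// embed builds an xs matrix like the one above: for each input rune, a
// one-hot token vector (vocab size 3) concatenated with a one-hot position
// vector (context size 8). Rows for unused context slots stay all-zero.
func embed(input []rune, vocab map[rune]int, ctx int) [][]float64 {
	dModel := len(vocab) + ctx
	xs := make([][]float64, ctx)
	for pos := range xs {
		xs[pos] = make([]float64, dModel)
		if pos < len(input) {
			xs[pos][vocab[input[pos]]] = 1       // token one-hot
			xs[pos][len(vocab)+pos] = 1          // position one-hot
		}
	}
	return xs
}

func main() {
	vocab := map[rune]int{'a': 0, '|': 1, 'b': 2}
	for _, row := range embed([]rune("aa|b"), vocab, 8) {
		fmt.Println(row)
	}
}
```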

First LayerNorm Gamma & Beta:
[ 1.419  0.851  1.668  1.085  1.197  0.888  1.118  1.106  1.002  1.047  1.110 ]
[ 0.096 -0.261 -0.122  0.183  0.029 -0.330  0.091 -0.029  0.120  0.255 -0.007 ]

Queries
[-0.210  0.288  0.406 -0.102 -0.046 -0.193 -0.320  0.325  0.569  0.492 -0.728 ]
[ 0.344 -0.172 -0.604 -0.363 -0.311  0.075 -0.259 -0.276 -0.169 -0.347  0.413 ]
[-0.258 -0.257  0.195 -0.352 -0.385  0.764  0.130  0.270 -0.241 -0.149 -0.211 ]
[ 0.160 -0.199  0.063 -0.220 -0.112  0.166 -0.813  0.417  0.594 -0.516 -0.012 ]
[-0.125 -0.319 -0.290 -0.129  0.013 -0.003 -0.684 -0.169 -0.118  0.503 -0.558 ]
[ 0.295  0.188  0.247  0.256 -0.594  0.500 -0.019 -0.159  0.300  0.724 -0.494 ]
[-0.285  0.125  0.019 -0.497 -0.423  0.537  0.819 -0.469 -0.181 -0.515  0.114 ]
[-0.021  0.096 -0.336  0.568  0.153  0.460  0.176  0.401  0.040  0.002 -0.398 ]

Q (xs * queries)
[-1.266  1.126 -2.024  0.005  0.369  1.227 -2.365  1.061 ]
[-1.121  1.182 -2.227  0.276  0.771 -1.338 -2.282 -0.061 ]
[-0.015  0.675  1.110 -0.021  0.481  0.987  1.591  0.811 ]
[ 0.618 -2.478  1.139 -2.049 -2.042  0.435  2.532 -1.405 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]

Keys
[ 0.257  0.180 -0.779 -0.650  0.566 -0.111 -0.036  0.258  0.407  0.252  0.014 ]
[ 0.647  0.093  0.547  0.001  0.125 -0.153  0.262 -0.030  0.234  0.091 -0.349 ]
[-0.506 -0.050 -0.085  0.463 -0.634  0.029 -0.247  0.175  0.031  0.171 -0.201 ]
[-0.349  0.489 -0.447  0.232 -0.388  0.369  0.029 -0.110 -0.418  0.230  0.239 ]
[ 0.479  0.156 -0.304  0.343 -0.359  0.407  0.186  0.174  0.094 -0.403  0.473 ]
[ 0.084  0.271  0.119 -0.300  0.222  0.326 -0.323  0.068 -0.394 -0.439 -0.489 ]
[-0.769  0.162  0.755 -0.597 -0.238  0.161 -0.110 -0.164  0.211  0.142  0.395 ]
[-0.252  0.205 -0.076 -0.482  0.012  0.684  0.527  0.124  0.429 -0.061 -0.264 ]

K (xs * keys)
[-0.784  1.482  0.075 -0.482  2.037 -0.530 -4.835 -2.861 ]
[ 2.801  1.868 -3.195 -2.338 -0.044  1.001 -3.895 -1.468 ]
[ 0.238 -1.044  0.589  2.080  0.589  1.354  0.399  1.451 ]
[-3.377  2.227 -0.451 -1.700 -1.467 -0.414  2.616  0.619 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]

QK
[ 3.319  4.186 -0.056  0.333  0.000  0.000  0.000  0.000 ]
[ 4.770  3.962 -1.386  0.111  0.000  0.000  0.000  0.000 ]
[-2.549 -2.624  1.005  1.399  0.000  0.000  0.000  0.000 ]
[-4.733 -2.720 -0.753  1.186  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]

S (triangular softmax QK)
[ 1.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.692  0.308  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.027  0.025  0.948  0.000  0.000  0.000  0.000  0.000 ]
[ 0.002  0.017  0.123  0.857  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]

Values
[ 0.267  0.922  0.351 -0.629 -0.087 -0.345  0.536 -0.166 ]
[-0.062  0.169 -0.817 -0.348 -0.232  0.051 -0.416  0.536 ]
[-0.660 -0.640  0.475  0.563  0.196 -0.416 -0.278 -0.037 ]
[ 0.098  0.750  0.526 -0.203  0.383 -0.457  0.263 -0.280 ]
[ 0.520  0.008 -0.252  0.129  0.029 -0.007 -0.192 -0.628 ]
[-0.370  0.406 -0.185 -0.563  0.061 -0.282 -0.472 -0.429 ]
[ 0.490 -0.414  0.093  0.138 -0.394 -0.506 -0.101  0.220 ]
[ 0.484 -0.138 -0.726  0.158  0.187  0.394 -0.093  0.086 ]
[-0.133 -0.004 -0.206 -0.285 -0.202  0.285 -0.278 -0.325 ]
[ 0.007  0.224 -0.283  0.185 -0.348 -0.002 -0.078  0.216 ]
[-0.148 -0.031  0.398  0.584  0.557  0.571  0.097 -0.036 ]

V (softmax * values)
[ 0.267 -0.062 -0.660  0.098  0.520 -0.370  0.490  0.484 -0.133  0.007 -0.148 ]
[ 0.469  0.009 -0.654  0.299  0.362 -0.131  0.211  0.292 -0.093  0.074 -0.112 ]
[ 0.363 -0.772  0.416  0.520 -0.224 -0.175  0.091 -0.679 -0.199 -0.262  0.373 ]
[-0.479 -0.396  0.529 -0.096  0.081 -0.499  0.124  0.045 -0.270  0.128  0.549 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]

Second LayerNorm Gamma & Beta:
[ 1.770  1.425  1.176  1.022  1.347  1.340  0.922  1.106  1.251  1.253  1.094 ]
[-0.192 -0.027  0.020 -0.093 -0.072  0.143  0.058 -0.076  0.248 -0.189 -0.197 ]

MLP Input layer
[-0.467 -0.155  0.333  0.114 -0.260 -0.352  0.464  0.097 -0.305 -0.477 -0.303 ]
[-0.380  0.412  0.593  0.049  0.715  0.024  0.327  0.676 -0.545  0.142  0.043 ]
[-0.463  0.379  0.354 -0.151 -0.117 -0.318 -0.512 -0.223  0.399  0.342  0.062 ]
[ 0.245  0.441 -0.184 -0.020  0.049  0.336  0.134 -0.021  0.280  0.250 -0.424 ]
[ 0.167  0.120  0.149  0.205 -0.425  0.348 -0.129 -0.441  0.555 -0.544  0.200 ]
[-0.310 -0.411 -0.359  0.149  0.532  0.864  0.021  0.257 -0.396 -0.052 -0.062 ]
[-0.353 -0.201 -0.095 -0.209  0.221  0.460 -0.124 -0.443  0.196 -0.143  0.156 ]
[-0.187  0.497  0.282 -0.206 -0.330  0.507  0.538  0.403  0.192 -0.120 -0.453 ]
[-0.291  0.420 -0.039  0.481 -0.559  0.353 -0.462 -0.761 -0.413 -0.223  0.038 ]
[-0.757  0.427 -0.091 -0.690  0.209 -0.625 -0.247  0.186 -0.089 -0.461 -0.609 ]
[-0.293  0.111  0.086 -0.556  0.500 -0.226 -0.221 -0.320 -0.153 -0.005 -0.205 ]
[ 0.884 -0.478 -0.536  0.124  0.195  0.421 -0.222  0.636 -0.062  0.056  0.454 ]
[ 0.659  0.236  0.228 -0.436  0.100 -0.391  0.461 -0.385  0.191 -0.230  0.831 ]
[ 0.504  0.279 -0.587 -0.330 -0.084  0.214 -0.444  0.172  0.070 -0.016 -0.395 ]
[ 0.588  0.307 -0.225 -0.170  0.434 -0.110  0.436 -0.279 -0.168 -0.116  0.119 ]
[ 0.153 -0.144 -0.112  0.213  0.209 -0.339  0.300 -0.766  0.072  0.315  0.081 ]
[-0.236 -0.343  0.488  0.409  0.617 -0.203  0.350 -0.192 -0.024  0.097 -0.584 ]
[-0.283 -0.549 -0.464  0.071  0.739 -0.446  0.176 -0.124 -0.256  0.278  0.035 ]
[ 0.577 -0.925  0.157 -0.430  0.324 -0.386  0.339 -0.442 -0.003  0.320 -0.053 ]
[-0.442  0.194  0.508 -0.077 -0.231  0.420  0.360 -0.038 -0.385  0.612  0.376 ]
[-0.130  0.145 -0.316 -0.577 -0.144 -0.132  0.190  0.201 -0.168 -0.145  0.096 ]
[-0.199 -0.626  0.234  0.279 -0.283  0.631 -0.285  0.096  0.500 -0.339  0.152 ]

MLP Hidden layer
[-0.116 -0.029 -0.204 -0.637 -0.101 -0.265 -0.080  0.443 -0.046  0.123  0.185 -0.817 -0.593  0.239  0.608  0.605 -0.171  0.583 -0.245  0.212  0.355  0.466 ]
[-0.142 -0.248  0.394  0.687  0.092 -0.563  0.396  0.249 -0.090  0.129  0.196  0.206  0.572  0.434  0.418  0.339 -0.381 -0.807 -0.551 -0.330  0.077 -0.613 ]
[-0.200  0.046  0.122 -0.190  0.086 -0.239 -0.075 -0.235  0.086 -0.086  0.150  0.107 -0.001 -0.251 -0.022  0.555  0.397  0.122  0.709  0.413 -0.144  0.007 ]
[-0.430 -0.501 -0.018  0.453  0.563 -0.021 -0.298  0.019 -0.046 -0.546 -0.063  0.352 -0.453 -0.269  0.233  0.129 -0.126  0.088 -0.412  0.304  0.283 -0.579 ]
[ 0.203 -0.217 -0.166  0.485 -0.056  0.299 -0.334 -0.159 -0.496 -0.641 -0.456 -0.019  0.208 -0.060 -0.103 -0.074  0.083  0.038  0.333 -0.222  0.356  0.089 ]
[ 0.546  0.518  0.294 -0.138  0.349 -0.067 -0.371  0.341 -0.582 -0.449  0.547  0.103 -0.297 -0.083 -0.593 -0.337 -0.353  0.229 -0.323  0.623  0.273 -0.247 ]
[ 0.305 -0.031  0.609 -0.079 -0.164 -0.369  0.284 -0.527  0.034  0.263  0.345 -0.465  0.298  0.070 -0.395  0.011  0.487  0.471 -0.001  0.695 -0.285  0.386 ]
[-0.110 -0.216 -0.093  0.140 -0.702 -0.135 -0.167 -0.120  0.364  0.013 -0.581  0.688 -0.258  0.441 -0.496  0.233 -0.035 -0.472 -0.636 -0.033 -0.250 -0.724 ]
[-0.214  0.600  0.425 -0.413  0.149  0.066 -0.100 -0.052 -0.440  0.299  0.564  0.201 -0.222 -0.298  0.317  0.222  0.250  0.373  0.053  0.225  0.056 -0.381 ]
[-0.014  0.496  0.043 -0.327  0.200 -0.220  0.418  0.287  0.007  0.116  0.238 -0.152  0.042 -0.786  0.120  0.186  0.195  0.230  0.540 -0.501 -0.418  0.365 ]
[ 0.391  0.392 -0.051 -0.012  0.801 -0.149 -0.067 -0.426  0.759  0.744  0.608 -0.373 -0.759 -0.353  0.087 -0.539  0.283 -0.140 -0.948  0.439  0.003 -0.240 ]

Linear
[ 0.239  0.271  0.532 ]
[ 1.243 -0.713 -0.397 ]
[-0.493  0.565 -0.249 ]
[ 0.383 -0.504 -0.201 ]
[-0.195  0.647 -0.911 ]
[-0.628  0.174  0.301 ]
[-0.813 -0.158  0.252 ]
[ 0.675 -0.162 -0.705 ]
[-0.009  0.023  0.282 ]
[-0.432  0.349  0.003 ]
[-0.068 -1.026  1.120 ]

Bias
[ 0.443 -0.362  0.360 ]

Logits (V * Linear + Bias)
[ 4.559  4.574 -8.249 ]
[ 3.451  8.749 -10.957 ]
[-2.658 -6.010 10.142 ]
[-17.608  3.578 10.739 ]
[ 0.443 -0.362  0.360 ]
[ 0.443 -0.362  0.360 ]
[ 0.443 -0.362  0.360 ]
[ 0.443 -0.362  0.360 ]

--------------------------------------
Detailed breakdown for the last token:
Last token's embedding:
[ 0.000  0.000  1.000  0.000  0.000  0.000  1.000  0.000  0.000  0.000  0.000 ]

After first LayerNorm:
[-0.573 -0.662  3.416 -0.328 -0.536 -0.749  2.463 -0.550 -0.352 -0.238 -0.531 ]

Token's Query:
[ 0.618 -2.478  1.139 -2.049 -2.042  0.435  2.532 -1.405 ]

Available Keys:
[-0.784  1.482  0.075 -0.482  2.037 -0.530 -4.835 -2.861 ]
[ 2.801  1.868 -3.195 -2.338 -0.044  1.001 -3.895 -1.468 ]
[ 0.238 -1.044  0.589  2.080  0.589  1.354  0.399  1.451 ]
[-3.377  2.227 -0.451 -1.700 -1.467 -0.414  2.616  0.619 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]
[ 0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 ]

Raw scores against Keys (QKᵀ):
[-4.733 -2.720 -0.753  1.186  0.000  0.000  0.000  0.000 ]

Normalized Softmax scores:
[ 0.002  0.017  0.123  0.857  0.000  0.000  0.000  0.000 ]

Dot product with Value rows:
[ 0.267  0.922  0.351 -0.629 -0.087 -0.345  0.536 -0.166 ]
[-0.062  0.169 -0.817 -0.348 -0.232  0.051 -0.416  0.536 ]
[-0.660 -0.640  0.475  0.563  0.196 -0.416 -0.278 -0.037 ]
[ 0.098  0.750  0.526 -0.203  0.383 -0.457  0.263 -0.280 ]
[ 0.520  0.008 -0.252  0.129  0.029 -0.007 -0.192 -0.628 ]
[-0.370  0.406 -0.185 -0.563  0.061 -0.282 -0.472 -0.429 ]
[ 0.490 -0.414  0.093  0.138 -0.394 -0.506 -0.101  0.220 ]
[ 0.484 -0.138 -0.726  0.158  0.187  0.394 -0.093  0.086 ]
[-0.133 -0.004 -0.206 -0.285 -0.202  0.285 -0.278 -0.325 ]
[ 0.007  0.224 -0.283  0.185 -0.348 -0.002 -0.078  0.216 ]
[-0.148 -0.031  0.398  0.584  0.557  0.571  0.097 -0.036 ]

To get the final Value:
[-0.479 -0.396  0.529 -0.096  0.081 -0.499  0.124  0.045 -0.270  0.128  0.549 ]
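
In this printout each of the 11 "Values" rows holds one embedding dimension across the 8 context positions, so the final value for dimension j is the softmax scores dotted with row j. A sketch of that weighted sum, assuming this layout (`attendValues` is a hypothetical name):

```go
package main

import "fmt"

// attendValues computes, for each embedding dimension j, the attention
// scores dotted with Values row j, yielding the final value vector.
func attendValues(scores []float64, values [][]float64) []float64 {
	out := make([]float64, len(values))
	for j, row := range values {
		for i, s := range scores {
			out[j] += s * row[i]
		}
	}
	return out
}

func main() {
	scores := []float64{0.002, 0.017, 0.123, 0.857, 0, 0, 0, 0}
	// First "Values" row from the printout above.
	values := [][]float64{{0.267, 0.922, 0.351, -0.629, -0.087, -0.345, 0.536, -0.166}}
	fmt.Printf("%.3f\n", attendValues(scores, values))
}
```

With the rounded scores this gives roughly -0.480 for the first dimension, matching the -0.479 above to within printing precision.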

Residual stream:
[ 0.000  0.000  1.000  0.000  0.000  0.000  1.000  0.000  0.000  0.000  0.000 ]
+
[-0.479 -0.396  0.529 -0.096  0.081 -0.499  0.124  0.045 -0.270  0.128  0.549 ]
=
[-0.479 -0.396  1.529 -0.096  0.081 -0.499  1.124  0.045 -0.270  0.128  0.549 ]

After second LayerNorm:
[-1.973 -1.275  2.580 -0.501 -0.233 -1.247  1.473 -0.270 -0.597 -0.245  0.484 ]

Pass through Input layer:
[ 3.231  2.143  0.857 -2.172 -0.708 -0.886  0.117  0.391 -0.766  1.247  0.857 -3.406  0.797 -3.875 -1.107  0.427  2.341  0.806  1.583  2.306  0.153  0.345 ]

Activation:
[ 3.231  2.143  0.857  0.000  0.000  0.000  0.117  0.391  0.000  1.247  0.857  0.000  0.797  0.000  0.000  0.427  2.341  0.806  1.583  2.306  0.153  0.345 ]
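
The activation step zeroes every negative pre-activation and passes positives through unchanged, consistent with ReLU:

```go
package main

import "fmt"

// relu clips negative pre-activations to zero; positives pass through.
func relu(x []float64) []float64 {
	out := make([]float64, len(x))
	for i, v := range x {
		if v > 0 {
			out[i] = v
		}
	}
	return out
}

func main() {
	// First few pre-activations from the "Pass through Input layer" step.
	fmt.Println(relu([]float64{3.231, 2.143, 0.857, -2.172, -0.708}))
	// prints [3.231 2.143 0.857 0 0]
}
```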

Pass through Hidden layer:
[ 0.037 -2.954  2.795 -3.880 -0.782  2.984  5.344 -3.385  3.066  2.084  2.478 ]

Residual stream:
[-0.479 -0.396  1.529 -0.096  0.081 -0.499  1.124  0.045 -0.270  0.128  0.549 ]
+
[ 0.037 -2.954  2.795 -3.880 -0.782  2.984  5.344 -3.385  3.066  2.084  2.478 ]
=
[-0.442 -3.351  4.324 -3.976 -0.702  2.485  6.468 -3.340  2.796  2.211  3.026 ]

Dot product with Linear layer rows:
[ 0.239  0.271  0.532 ]
[ 1.243 -0.713 -0.397 ]
[-0.493  0.565 -0.249 ]
[ 0.383 -0.504 -0.201 ]
[-0.195  0.647 -0.911 ]
[-0.628  0.174  0.301 ]
[-0.813 -0.158  0.252 ]
[ 0.675 -0.162 -0.705 ]
[-0.009  0.023  0.282 ]
[-0.432  0.349  0.003 ]
[-0.068 -1.026  1.120 ]

And add Bias:
[ 0.443 -0.362  0.360 ]

To get the final Logits:
[-17.608  3.578 10.739 ]

Input: [aa|b]
Next token probabilities:
[a] -> 0.000000
[|] -> 0.000776
[b] -> 0.999224
