Minimal from-scratch implementation of a basic single-block transformer. It trains locally using SPSA (simultaneous perturbation stochastic approximation), with no external dependencies or API calls.
Configure an example (here the model learns a simple aa|bb|aa|bb|... sequence):

t := trainModel(ctx, []rune("aa|bb|aa|bb|aa|bb|"), 4, seed, 5000, 0.01, 0.0001)

Run in the terminal:
$ go run .
Model 0 has loss 0.1352
Model 1 has loss 0.0848
Model 2 has loss 0.1096
Model 3 has loss 0.1206
The winner is 1! (828 parameters, d_model 11, ctx 8, vocab 3, 0.0848 loss)
Seed used 1769723094689693640
Trained in 0.591 seconds
Enter context, up to 8 chars: aa|b
Full breakdown:
Input embeddings (xs)
[ 1.000 0.000 0.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 1.000 0.000 0.000 0.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 1.000 0.000 0.000 0.000 1.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 1.000 0.000 0.000 0.000 1.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
First LayerNorm Gamma & Beta:
[ 1.419 0.851 1.668 1.085 1.197 0.888 1.118 1.106 1.002 1.047 1.110 ]
[ 0.096 -0.261 -0.122 0.183 0.029 -0.330 0.091 -0.029 0.120 0.255 -0.007 ]
Queries
[-0.210 0.288 0.406 -0.102 -0.046 -0.193 -0.320 0.325 0.569 0.492 -0.728 ]
[ 0.344 -0.172 -0.604 -0.363 -0.311 0.075 -0.259 -0.276 -0.169 -0.347 0.413 ]
[-0.258 -0.257 0.195 -0.352 -0.385 0.764 0.130 0.270 -0.241 -0.149 -0.211 ]
[ 0.160 -0.199 0.063 -0.220 -0.112 0.166 -0.813 0.417 0.594 -0.516 -0.012 ]
[-0.125 -0.319 -0.290 -0.129 0.013 -0.003 -0.684 -0.169 -0.118 0.503 -0.558 ]
[ 0.295 0.188 0.247 0.256 -0.594 0.500 -0.019 -0.159 0.300 0.724 -0.494 ]
[-0.285 0.125 0.019 -0.497 -0.423 0.537 0.819 -0.469 -0.181 -0.515 0.114 ]
[-0.021 0.096 -0.336 0.568 0.153 0.460 0.176 0.401 0.040 0.002 -0.398 ]
Q (xs * queries)
[-1.266 1.126 -2.024 0.005 0.369 1.227 -2.365 1.061 ]
[-1.121 1.182 -2.227 0.276 0.771 -1.338 -2.282 -0.061 ]
[-0.015 0.675 1.110 -0.021 0.481 0.987 1.591 0.811 ]
[ 0.618 -2.478 1.139 -2.049 -2.042 0.435 2.532 -1.405 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
Keys
[ 0.257 0.180 -0.779 -0.650 0.566 -0.111 -0.036 0.258 0.407 0.252 0.014 ]
[ 0.647 0.093 0.547 0.001 0.125 -0.153 0.262 -0.030 0.234 0.091 -0.349 ]
[-0.506 -0.050 -0.085 0.463 -0.634 0.029 -0.247 0.175 0.031 0.171 -0.201 ]
[-0.349 0.489 -0.447 0.232 -0.388 0.369 0.029 -0.110 -0.418 0.230 0.239 ]
[ 0.479 0.156 -0.304 0.343 -0.359 0.407 0.186 0.174 0.094 -0.403 0.473 ]
[ 0.084 0.271 0.119 -0.300 0.222 0.326 -0.323 0.068 -0.394 -0.439 -0.489 ]
[-0.769 0.162 0.755 -0.597 -0.238 0.161 -0.110 -0.164 0.211 0.142 0.395 ]
[-0.252 0.205 -0.076 -0.482 0.012 0.684 0.527 0.124 0.429 -0.061 -0.264 ]
K (xs * keys)
[-0.784 1.482 0.075 -0.482 2.037 -0.530 -4.835 -2.861 ]
[ 2.801 1.868 -3.195 -2.338 -0.044 1.001 -3.895 -1.468 ]
[ 0.238 -1.044 0.589 2.080 0.589 1.354 0.399 1.451 ]
[-3.377 2.227 -0.451 -1.700 -1.467 -0.414 2.616 0.619 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
QK (Q · K^T, scaled by 1/sqrt(d_model))
[ 3.319 4.186 -0.056 0.333 0.000 0.000 0.000 0.000 ]
[ 4.770 3.962 -1.386 0.111 0.000 0.000 0.000 0.000 ]
[-2.549 -2.624 1.005 1.399 0.000 0.000 0.000 0.000 ]
[-4.733 -2.720 -0.753 1.186 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
S (triangular softmax QK)
[ 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.692 0.308 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.027 0.025 0.948 0.000 0.000 0.000 0.000 0.000 ]
[ 0.002 0.017 0.123 0.857 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
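The triangular softmax is a causal mask: row i is normalized only over columns 0..i, so each token attends only to itself and earlier positions, and rows past the input length stay zero. A sketch reproducing the S matrix from the QK rows above (helper name is illustrative; padding columns omitted):

```go
package main

import (
	"fmt"
	"math"
)

// causalSoftmax applies a row-wise softmax to scores, masking out
// columns j > i (future positions) and leaving rows past seqLen zero.
func causalSoftmax(scores [][]float64, seqLen int) [][]float64 {
	out := make([][]float64, len(scores))
	for i := range scores {
		out[i] = make([]float64, len(scores[i]))
		if i >= seqLen {
			continue // padding row: all zeros
		}
		// Subtract the row max for numerical stability.
		max := math.Inf(-1)
		for j := 0; j <= i; j++ {
			if scores[i][j] > max {
				max = scores[i][j]
			}
		}
		sum := 0.0
		for j := 0; j <= i; j++ {
			out[i][j] = math.Exp(scores[i][j] - max)
			sum += out[i][j]
		}
		for j := 0; j <= i; j++ {
			out[i][j] /= sum
		}
	}
	return out
}

func main() {
	// The four non-zero QK rows from the breakdown above.
	qk := [][]float64{
		{3.319, 4.186, -0.056, 0.333},
		{4.770, 3.962, -1.386, 0.111},
		{-2.549, -2.624, 1.005, 1.399},
		{-4.733, -2.720, -0.753, 1.186},
	}
	s := causalSoftmax(qk, 4)
	// Row 1 attends only to positions 0 and 1:
	fmt.Printf("%.3f\n", s[1]) // prints [0.692 0.308 0.000 0.000], matching the S row above
}
```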
Values
[ 0.267 0.922 0.351 -0.629 -0.087 -0.345 0.536 -0.166 ]
[-0.062 0.169 -0.817 -0.348 -0.232 0.051 -0.416 0.536 ]
[-0.660 -0.640 0.475 0.563 0.196 -0.416 -0.278 -0.037 ]
[ 0.098 0.750 0.526 -0.203 0.383 -0.457 0.263 -0.280 ]
[ 0.520 0.008 -0.252 0.129 0.029 -0.007 -0.192 -0.628 ]
[-0.370 0.406 -0.185 -0.563 0.061 -0.282 -0.472 -0.429 ]
[ 0.490 -0.414 0.093 0.138 -0.394 -0.506 -0.101 0.220 ]
[ 0.484 -0.138 -0.726 0.158 0.187 0.394 -0.093 0.086 ]
[-0.133 -0.004 -0.206 -0.285 -0.202 0.285 -0.278 -0.325 ]
[ 0.007 0.224 -0.283 0.185 -0.348 -0.002 -0.078 0.216 ]
[-0.148 -0.031 0.398 0.584 0.557 0.571 0.097 -0.036 ]
V (softmax * values)
[ 0.267 -0.062 -0.660 0.098 0.520 -0.370 0.490 0.484 -0.133 0.007 -0.148 ]
[ 0.469 0.009 -0.654 0.299 0.362 -0.131 0.211 0.292 -0.093 0.074 -0.112 ]
[ 0.363 -0.772 0.416 0.520 -0.224 -0.175 0.091 -0.679 -0.199 -0.262 0.373 ]
[-0.479 -0.396 0.529 -0.096 0.081 -0.499 0.124 0.045 -0.270 0.128 0.549 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
Second LayerNorm Gamma & Beta:
[ 1.770 1.425 1.176 1.022 1.347 1.340 0.922 1.106 1.251 1.253 1.094 ]
[-0.192 -0.027 0.020 -0.093 -0.072 0.143 0.058 -0.076 0.248 -0.189 -0.197 ]
MLP Input layer
[-0.467 -0.155 0.333 0.114 -0.260 -0.352 0.464 0.097 -0.305 -0.477 -0.303 ]
[-0.380 0.412 0.593 0.049 0.715 0.024 0.327 0.676 -0.545 0.142 0.043 ]
[-0.463 0.379 0.354 -0.151 -0.117 -0.318 -0.512 -0.223 0.399 0.342 0.062 ]
[ 0.245 0.441 -0.184 -0.020 0.049 0.336 0.134 -0.021 0.280 0.250 -0.424 ]
[ 0.167 0.120 0.149 0.205 -0.425 0.348 -0.129 -0.441 0.555 -0.544 0.200 ]
[-0.310 -0.411 -0.359 0.149 0.532 0.864 0.021 0.257 -0.396 -0.052 -0.062 ]
[-0.353 -0.201 -0.095 -0.209 0.221 0.460 -0.124 -0.443 0.196 -0.143 0.156 ]
[-0.187 0.497 0.282 -0.206 -0.330 0.507 0.538 0.403 0.192 -0.120 -0.453 ]
[-0.291 0.420 -0.039 0.481 -0.559 0.353 -0.462 -0.761 -0.413 -0.223 0.038 ]
[-0.757 0.427 -0.091 -0.690 0.209 -0.625 -0.247 0.186 -0.089 -0.461 -0.609 ]
[-0.293 0.111 0.086 -0.556 0.500 -0.226 -0.221 -0.320 -0.153 -0.005 -0.205 ]
[ 0.884 -0.478 -0.536 0.124 0.195 0.421 -0.222 0.636 -0.062 0.056 0.454 ]
[ 0.659 0.236 0.228 -0.436 0.100 -0.391 0.461 -0.385 0.191 -0.230 0.831 ]
[ 0.504 0.279 -0.587 -0.330 -0.084 0.214 -0.444 0.172 0.070 -0.016 -0.395 ]
[ 0.588 0.307 -0.225 -0.170 0.434 -0.110 0.436 -0.279 -0.168 -0.116 0.119 ]
[ 0.153 -0.144 -0.112 0.213 0.209 -0.339 0.300 -0.766 0.072 0.315 0.081 ]
[-0.236 -0.343 0.488 0.409 0.617 -0.203 0.350 -0.192 -0.024 0.097 -0.584 ]
[-0.283 -0.549 -0.464 0.071 0.739 -0.446 0.176 -0.124 -0.256 0.278 0.035 ]
[ 0.577 -0.925 0.157 -0.430 0.324 -0.386 0.339 -0.442 -0.003 0.320 -0.053 ]
[-0.442 0.194 0.508 -0.077 -0.231 0.420 0.360 -0.038 -0.385 0.612 0.376 ]
[-0.130 0.145 -0.316 -0.577 -0.144 -0.132 0.190 0.201 -0.168 -0.145 0.096 ]
[-0.199 -0.626 0.234 0.279 -0.283 0.631 -0.285 0.096 0.500 -0.339 0.152 ]
MLP Hidden layer
[-0.116 -0.029 -0.204 -0.637 -0.101 -0.265 -0.080 0.443 -0.046 0.123 0.185 -0.817 -0.593 0.239 0.608 0.605 -0.171 0.583 -0.245 0.212 0.355 0.466 ]
[-0.142 -0.248 0.394 0.687 0.092 -0.563 0.396 0.249 -0.090 0.129 0.196 0.206 0.572 0.434 0.418 0.339 -0.381 -0.807 -0.551 -0.330 0.077 -0.613 ]
[-0.200 0.046 0.122 -0.190 0.086 -0.239 -0.075 -0.235 0.086 -0.086 0.150 0.107 -0.001 -0.251 -0.022 0.555 0.397 0.122 0.709 0.413 -0.144 0.007 ]
[-0.430 -0.501 -0.018 0.453 0.563 -0.021 -0.298 0.019 -0.046 -0.546 -0.063 0.352 -0.453 -0.269 0.233 0.129 -0.126 0.088 -0.412 0.304 0.283 -0.579 ]
[ 0.203 -0.217 -0.166 0.485 -0.056 0.299 -0.334 -0.159 -0.496 -0.641 -0.456 -0.019 0.208 -0.060 -0.103 -0.074 0.083 0.038 0.333 -0.222 0.356 0.089 ]
[ 0.546 0.518 0.294 -0.138 0.349 -0.067 -0.371 0.341 -0.582 -0.449 0.547 0.103 -0.297 -0.083 -0.593 -0.337 -0.353 0.229 -0.323 0.623 0.273 -0.247 ]
[ 0.305 -0.031 0.609 -0.079 -0.164 -0.369 0.284 -0.527 0.034 0.263 0.345 -0.465 0.298 0.070 -0.395 0.011 0.487 0.471 -0.001 0.695 -0.285 0.386 ]
[-0.110 -0.216 -0.093 0.140 -0.702 -0.135 -0.167 -0.120 0.364 0.013 -0.581 0.688 -0.258 0.441 -0.496 0.233 -0.035 -0.472 -0.636 -0.033 -0.250 -0.724 ]
[-0.214 0.600 0.425 -0.413 0.149 0.066 -0.100 -0.052 -0.440 0.299 0.564 0.201 -0.222 -0.298 0.317 0.222 0.250 0.373 0.053 0.225 0.056 -0.381 ]
[-0.014 0.496 0.043 -0.327 0.200 -0.220 0.418 0.287 0.007 0.116 0.238 -0.152 0.042 -0.786 0.120 0.186 0.195 0.230 0.540 -0.501 -0.418 0.365 ]
[ 0.391 0.392 -0.051 -0.012 0.801 -0.149 -0.067 -0.426 0.759 0.744 0.608 -0.373 -0.759 -0.353 0.087 -0.539 0.283 -0.140 -0.948 0.439 0.003 -0.240 ]
Linear
[ 0.239 0.271 0.532 ]
[ 1.243 -0.713 -0.397 ]
[-0.493 0.565 -0.249 ]
[ 0.383 -0.504 -0.201 ]
[-0.195 0.647 -0.911 ]
[-0.628 0.174 0.301 ]
[-0.813 -0.158 0.252 ]
[ 0.675 -0.162 -0.705 ]
[-0.009 0.023 0.282 ]
[-0.432 0.349 0.003 ]
[-0.068 -1.026 1.120 ]
Bias
[ 0.443 -0.362 0.360 ]
Logits (residual * Linear + Bias)
[ 4.559 4.574 -8.249 ]
[ 3.451 8.749 -10.957 ]
[-2.658 -6.010 10.142 ]
[-17.608 3.578 10.739 ]
[ 0.443 -0.362 0.360 ]
[ 0.443 -0.362 0.360 ]
[ 0.443 -0.362 0.360 ]
[ 0.443 -0.362 0.360 ]
--------------------------------------
Detailed breakdown for the last token:
Last token's embedding:
[ 0.000 0.000 1.000 0.000 0.000 0.000 1.000 0.000 0.000 0.000 0.000 ]
After first LayerNorm:
[-0.573 -0.662 3.416 -0.328 -0.536 -0.749 2.463 -0.550 -0.352 -0.238 -0.531 ]
Token's Query:
[ 0.618 -2.478 1.139 -2.049 -2.042 0.435 2.532 -1.405 ]
Available Keys:
[-0.784 1.482 0.075 -0.482 2.037 -0.530 -4.835 -2.861 ]
[ 2.801 1.868 -3.195 -2.338 -0.044 1.001 -3.895 -1.468 ]
[ 0.238 -1.044 0.589 2.080 0.589 1.354 0.399 1.451 ]
[-3.377 2.227 -0.451 -1.700 -1.467 -0.414 2.616 0.619 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
[ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 ]
Raw scores against Keys (QK^T):
[-4.733 -2.720 -0.753 1.186 0.000 0.000 0.000 0.000 ]
Normalized Softmax scores:
[ 0.002 0.017 0.123 0.857 0.000 0.000 0.000 0.000 ]
Dot product with Value rows:
[ 0.267 0.922 0.351 -0.629 -0.087 -0.345 0.536 -0.166 ]
[-0.062 0.169 -0.817 -0.348 -0.232 0.051 -0.416 0.536 ]
[-0.660 -0.640 0.475 0.563 0.196 -0.416 -0.278 -0.037 ]
[ 0.098 0.750 0.526 -0.203 0.383 -0.457 0.263 -0.280 ]
[ 0.520 0.008 -0.252 0.129 0.029 -0.007 -0.192 -0.628 ]
[-0.370 0.406 -0.185 -0.563 0.061 -0.282 -0.472 -0.429 ]
[ 0.490 -0.414 0.093 0.138 -0.394 -0.506 -0.101 0.220 ]
[ 0.484 -0.138 -0.726 0.158 0.187 0.394 -0.093 0.086 ]
[-0.133 -0.004 -0.206 -0.285 -0.202 0.285 -0.278 -0.325 ]
[ 0.007 0.224 -0.283 0.185 -0.348 -0.002 -0.078 0.216 ]
[-0.148 -0.031 0.398 0.584 0.557 0.571 0.097 -0.036 ]
To get the final Value:
[-0.479 -0.396 0.529 -0.096 0.081 -0.499 0.124 0.045 -0.270 0.128 0.549 ]
Residual stream:
[ 0.000 0.000 1.000 0.000 0.000 0.000 1.000 0.000 0.000 0.000 0.000 ]
+
[-0.479 -0.396 0.529 -0.096 0.081 -0.499 0.124 0.045 -0.270 0.128 0.549 ]
=
[-0.479 -0.396 1.529 -0.096 0.081 -0.499 1.124 0.045 -0.270 0.128 0.549 ]
After second LayerNorm:
[-1.973 -1.275 2.580 -0.501 -0.233 -1.247 1.473 -0.270 -0.597 -0.245 0.484 ]
Pass through Input layer:
[ 3.231 2.143 0.857 -2.172 -0.708 -0.886 0.117 0.391 -0.766 1.247 0.857 -3.406 0.797 -3.875 -1.107 0.427 2.341 0.806 1.583 2.306 0.153 0.345 ]
Activation (ReLU):
[ 3.231 2.143 0.857 0.000 0.000 0.000 0.117 0.391 0.000 1.247 0.857 0.000 0.797 0.000 0.000 0.427 2.341 0.806 1.583 2.306 0.153 0.345 ]
Pass through Hidden layer:
[ 0.037 -2.954 2.795 -3.880 -0.782 2.984 5.344 -3.385 3.066 2.084 2.478 ]
Residual stream:
[-0.479 -0.396 1.529 -0.096 0.081 -0.499 1.124 0.045 -0.270 0.128 0.549 ]
+
[ 0.037 -2.954 2.795 -3.880 -0.782 2.984 5.344 -3.385 3.066 2.084 2.478 ]
=
[-0.442 -3.351 4.324 -3.976 -0.702 2.485 6.468 -3.340 2.796 2.211 3.026 ]
Dot product with Linear layer rows:
[ 0.239 0.271 0.532 ]
[ 1.243 -0.713 -0.397 ]
[-0.493 0.565 -0.249 ]
[ 0.383 -0.504 -0.201 ]
[-0.195 0.647 -0.911 ]
[-0.628 0.174 0.301 ]
[-0.813 -0.158 0.252 ]
[ 0.675 -0.162 -0.705 ]
[-0.009 0.023 0.282 ]
[-0.432 0.349 0.003 ]
[-0.068 -1.026 1.120 ]
And add Bias:
[ 0.443 -0.362 0.360 ]
To get the final Logits:
[-17.608 3.578 10.739 ]
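The unembedding step is a plain matrix-vector product: each of the 11 residual dimensions contributes its Linear row, weighted by the residual value, and the bias is added. A sketch reproducing the logits above (to within rounding, since the inputs are printed at three decimals; the helper name is illustrative):

```go
package main

import "fmt"

// project computes logits[j] = sum_i residual[i]*linear[i][j] + bias[j],
// i.e. the residual stream times the Linear layer, plus the bias.
func project(residual []float64, linear [][]float64, bias []float64) []float64 {
	out := make([]float64, len(bias))
	copy(out, bias)
	for i, r := range residual {
		for j := range out {
			out[j] += r * linear[i][j]
		}
	}
	return out
}

func main() {
	// Post-MLP residual, Linear rows, and bias from the breakdown above.
	residual := []float64{-0.442, -3.351, 4.324, -3.976, -0.702, 2.485, 6.468, -3.340, 2.796, 2.211, 3.026}
	linear := [][]float64{
		{0.239, 0.271, 0.532},
		{1.243, -0.713, -0.397},
		{-0.493, 0.565, -0.249},
		{0.383, -0.504, -0.201},
		{-0.195, 0.647, -0.911},
		{-0.628, 0.174, 0.301},
		{-0.813, -0.158, 0.252},
		{0.675, -0.162, -0.705},
		{-0.009, 0.023, 0.282},
		{-0.432, 0.349, 0.003},
		{-0.068, -1.026, 1.120},
	}
	bias := []float64{0.443, -0.362, 0.360}
	fmt.Printf("%.2f\n", project(residual, linear, bias)) // close to [-17.61 3.58 10.74]
}
```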
Input: [aa|b]
Next token probabilities:
[a] -> 0.000000
[|] -> 0.000776
[b] -> 0.999224