library(torch)
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, where d_k is
# the size of the last dimension of Q.
#
# Args:
#   Q, K, V: query, key and value tensors. The last two dimensions of K are
#     swapped before it is multiplied with Q.
#   mask: optional mask handed to masked_fill(); the positions it selects get
#     a score of -Inf before the softmax. NULL (the default) skips masking.
#
# Returns an unnamed list: the attention output first, the attention weights
# second.
scaled_dot_product_attention <- function(Q, K, V, mask = NULL) {
  key_dim <- tail(Q$shape, 1L)
  keys_t <- K$transpose(-2, -1)
  attn_logits <- Q$matmul(keys_t) / sqrt(key_dim)

  if (!is.null(mask)) {
    attn_logits <- attn_logits$masked_fill(mask, -Inf)
  }

  # Normalise over the last dimension.
  attn <- nnf_softmax(attn_logits, dim = -1)
  context <- attn$matmul(V)
  list(context, attn)
}
# One transformer block: multi-head self-attention followed by a two-layer
# feed-forward network (linear -> ReLU -> linear). Each sub-layer is wrapped
# in a residual connection, and layer normalisation is applied after the
# residual addition.
#
# initialize(embed_dim, num_heads, ff_dim):
#   embed_dim: width of the input/output features and of both layer norms.
#   num_heads: number of attention heads.
#   ff_dim:    hidden width of the feed-forward network.
# forward(x, attn_mask = NULL):
#   x is used as query, key and value; attn_mask is forwarded unchanged to the
#   attention layer. Returns a tensor produced by the second layer norm.
TransformerBlock <- nn_module(
  initialize = function(embed_dim, num_heads, ff_dim) {
    # batch_first = TRUE: the attention layer is configured for inputs whose
    # first dimension is the batch.
    self$attn <- nn_multihead_attention(
      embed_dim,
      num_heads,
      batch_first = TRUE
    )
    self$ff <- nn_sequential(
      nn_linear(embed_dim, ff_dim),
      nn_relu(),
      nn_linear(ff_dim, embed_dim)
    )
    self$norm1 <- nn_layer_norm(embed_dim)
    self$norm2 <- nn_layer_norm(embed_dim)
  },
  forward = function(x, attn_mask = NULL) {
    # The attention layer returns a list; its first element is the output.
    attn_result <- self$attn(x, x, x, attn_mask = attn_mask)
    hidden <- self$norm1(x + attn_result[[1]])
    self$norm2(hidden + self$ff(hidden))
  }
)
Implementing a GPT Language Model
We implement a minimal GPT-style language model from scratch in torch and train it on a character-level text corpus. The topic closes with text generation strategies: temperature scaling, top-k, and top-p sampling.
A minimal GPT
We assemble the components from the previous session into a complete GPT-style model. The model takes a sequence of token indices as input and produces a distribution over the vocabulary at each position, from which the next token can be sampled.
We reuse the TransformerBlock and scaled_dot_product_attention defined in the previous session.
The GPT class adds token embeddings, positional embeddings, a stack of transformer blocks, a final layer normalisation, and a linear head that projects to vocabulary logits. The causal mask is constructed inside forward from the current sequence length.
# Minimal GPT-style language model.
#
# initialize(vocab_size, embed_dim, num_heads, num_layers, max_seq_len):
#   vocab_size:  number of distinct tokens; also the width of the output logits.
#   embed_dim:   width of the token and positional embeddings.
#   num_heads:   attention heads per transformer block.
#   num_layers:  number of stacked TransformerBlock modules.
#   max_seq_len: number of rows in the positional embedding table, i.e. the
#                longest sequence forward() accepts.
#
# forward(idx):
#   idx: tensor of token indices shaped [batch, sequence].
#   Returns logits shaped [batch, sequence, vocab_size].
#   Stops with an error when the sequence is longer than max_seq_len.
GPT <- nn_module(
  initialize = function(vocab_size, embed_dim, num_heads, num_layers, max_seq_len) {
    self$tok_emb <- nn_embedding(vocab_size, embed_dim)
    self$pos_emb <- nn_embedding(max_seq_len, embed_dim)
    # Every block uses a feed-forward width of four times the embedding width.
    self$blocks <- nn_module_list(lapply(
      seq_len(num_layers),
      function(i) TransformerBlock(embed_dim, num_heads, ff_dim = 4L * embed_dim)
    ))
    self$ln <- nn_layer_norm(embed_dim)
    self$head <- nn_linear(embed_dim, vocab_size, bias = FALSE)
    self$max_seq_len <- max_seq_len
  },
  forward = function(idx) {
    # Named n_tokens rather than `T`, which would shadow R's TRUE shorthand.
    n_tokens <- idx$shape[2]
    if (n_tokens > self$max_seq_len) {
      # pos_emb only has max_seq_len rows, so fail with a clear message
      # instead of an out-of-range embedding lookup.
      stop(
        "Sequence length ", n_tokens, " exceeds max_seq_len (",
        self$max_seq_len, ").",
        call. = FALSE
      )
    }

    # Positions 1..n_tokens index the positional table; the result is added
    # to every sequence in the batch by broadcasting.
    positions <- torch_arange(
      1, n_tokens,
      dtype = torch_long(), device = idx$device
    )
    x <- self$tok_emb(idx) + self$pos_emb(positions)

    # Causal mask: TRUE strictly above the diagonal, i.e. at the positions
    # that lie in the future of each query position.
    causal_mask <- torch_triu(
      torch_ones(n_tokens, n_tokens, device = idx$device), diagonal = 1
    )$to(dtype = torch_bool())

    for (i in seq_along(self$blocks)) {
      x <- self$blocks[[i]](x, attn_mask = causal_mask)
    }

    # Final layer norm, then projection to vocabulary logits.
    self$head(self$ln(x))
  }
)
Preparing text data
We use the Tiny Shakespeare dataset, a widely used benchmark for character-level language models.
# Fetch the Tiny Shakespeare corpus and keep it as a single string, with the
# original line breaks restored as "\n".
url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
full_text <- paste(readLines(url, warn = FALSE), collapse = "\n")
cat(sprintf("Full corpus: %d characters\n", nchar(full_text)))Full corpus: 1115393 characters
cat(substr(full_text, 1, 200), "\n")First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you
The full corpus is about 1 million characters. For this rendered document we use a short excerpt to keep execution time manageable. In a live session, use the full text for a model that learns more coherent Shakespeare-like patterns.
# Working corpus: a prefix of full_text. The vocabulary and the encoded
# training data are both derived from `text`.
text <- substr(full_text, 1, 10000) # use first 10k characters for rendering
# replace with full_text in a live session
We build the character vocabulary and encode the corpus as a sequence of integers.
# Character vocabulary: the sorted distinct characters of `text`.
chars <- sort(unique(strsplit(text, "")[[1]]))
vocab_size <- length(chars)

# stoi maps a character to its 1-based id; itos maps the id (as a string)
# back to the character.
stoi <- seq_along(chars)
names(stoi) <- chars
itos <- chars
names(itos) <- as.character(seq_along(chars))

# Encode the corpus: each character becomes its position in `chars`, which is
# the same id that stoi assigns to it.
data <- torch_tensor(
  match(strsplit(text, "")[[1]], chars),
  dtype = torch_long()
)
cat(sprintf("Vocab size: %d\n", vocab_size))Vocab size: 57
cat(sprintf("Encoded length: %d\n", data$shape[1]))Encoded length: 10000
We create input-target pairs by extracting overlapping windows of length seq_len. The target at each position is the token one step ahead.
# Build overlapping training windows.
#   X_mat[i, j] = data_vec[i + j - 1]   (input window starting at position i)
#   y_mat[i, j] = data_vec[i + j]       (the same window shifted one step ahead)
# n is the number of windows that still have a full target window after them.
seq_len <- 64L
data_vec <- as.integer(data)
n <- length(data_vec) - seq_len
X_mat <- matrix(0L, nrow = n, ncol = seq_len)
y_mat <- matrix(0L, nrow = n, ncol = seq_len)

# Fill one column at a time: each assignment is a single vectorised slice over
# all n windows, so the loop runs seq_len times instead of n times.
row_starts <- seq_len(n)
for (j in seq_len(ncol(X_mat))) {
  X_mat[, j] <- data_vec[row_starts + (j - 1L)]
  y_mat[, j] <- data_vec[row_starts + j]
}

X <- torch_tensor(X_mat, dtype = torch_long())
y <- torch_tensor(y_mat, dtype = torch_long())
X$shape[1] 9936 64
y$shape[1] 9936 64
# Pair each input window with its target window, then serve the pairs in
# shuffled mini-batches of 64 sequences.
dataset <- tensor_dataset(X, y)
loader <- dataloader(
  dataset,
  batch_size = 64L,
  shuffle = TRUE
)
Training
We instantiate a small model — small enough to train on CPU in a short time.
# Instantiate the model. The feed-forward width is not passed here: GPT sets
# it to 4 * embed_dim for every block. max_seq_len is tied to the training
# window length, so the positional table has exactly seq_len rows.
model <- GPT(
vocab_size = vocab_size,
embed_dim = 64L,
num_heads = 4L,
num_layers = 2L,
max_seq_len = seq_len
)
sum(sapply(model$parameters, function(p) prod(p$shape)))[1] 111488
The training loop is identical to the one from day one. The loss is averaged across all positions in every sequence.
# Train for five epochs with Adam and cross-entropy, printing the mean batch
# loss after every epoch.
criterion <- nn_cross_entropy_loss()
optimizer <- optim_adam(model$parameters, lr = 3e-3)

for (epoch in seq_len(5)) {
  model$train()
  epoch_loss <- 0

  coro::loop(for (batch in loader) {
    inputs <- batch[[1]]
    targets <- batch[[2]]

    optimizer$zero_grad()
    logits <- model(inputs) # [B, T, vocab_size]

    # Flatten batch and time so every position counts as one example.
    flat_logits <- logits$view(c(-1L, vocab_size)) # [B*T, vocab_size]
    flat_targets <- targets$view(-1L) # [B*T]
    loss <- criterion(flat_logits, flat_targets)

    loss$backward()
    optimizer$step()
    epoch_loss <- epoch_loss + loss$item()
  })

  mean_loss <- epoch_loss / length(loader)
  cat(sprintf("Epoch %d: loss=%.4f\n", epoch, mean_loss))
}
Epoch 1: loss=2.4202
Epoch 2: loss=1.7198
Epoch 3: loss=1.1202
Epoch 4: loss=0.6756
Epoch 5: loss=0.4246
The logits tensor has shape [B, T, vocab_size]. We reshape it to [B*T, vocab_size] before passing to nn_cross_entropy_loss, which expects a 2D tensor of scores and a 1D tensor of target indices.
Text generation
To generate text, we feed a seed sequence into the model, sample the next token from the predicted distribution, append it to the sequence, and repeat.
# Autoregressively sample tokens from a trained model.
#
# Args:
#   model:          module that maps a [1, T] tensor of token ids to
#                   [1, T, vocab] logits and exposes `max_seq_len`.
#   seed_ids:       1-D tensor of token ids used as the starting context.
#   max_new_tokens: number of tokens to sample.
#   temperature:    positive number; logits are divided by it before the
#                   softmax.
#   top_k:          NULL, or a positive integer. When given, only the top_k
#                   highest logits stay eligible at each step. Values larger
#                   than the vocabulary size are clamped to it.
#
# Returns a 1-D tensor holding the seed followed by the sampled ids.
# Side effect: the model is switched to eval mode and left there.
generate <- function(model, seed_ids, max_new_tokens, temperature = 1.0, top_k = NULL) {
  if (!is.numeric(temperature) || length(temperature) != 1L ||
    is.na(temperature) || temperature <= 0) {
    stop("`temperature` must be a single positive number.", call. = FALSE)
  }
  if (!is.null(top_k) &&
    (!is.numeric(top_k) || length(top_k) != 1L || is.na(top_k) || top_k < 1)) {
    stop("`top_k` must be NULL or a single positive integer.", call. = FALSE)
  }

  model$eval()
  idx <- seed_ids$unsqueeze(1) # add batch dim: [1, T]

  with_no_grad({
    for (i in seq_len(max_new_tokens)) {
      # Condition on at most the last max_seq_len tokens.
      T_curr <- idx$shape[2]
      start <- max(1L, T_curr - model$max_seq_len + 1L)
      # drop = FALSE keeps the context two-dimensional whatever its length.
      idx_cond <- idx[, start:T_curr, drop = FALSE]
      T_cond <- idx_cond$shape[2]

      logits <- model(idx_cond)[, T_cond, ] # logits at last position: [1, vocab]
      logits <- logits / temperature

      if (!is.null(top_k)) {
        # Never ask for more entries than the vocabulary holds.
        k <- as.integer(min(top_k, logits$shape[2]))
        top_vals <- torch_topk(logits, k)[[1]]
        # The k-th largest logit is the cut-off; kept as [1, 1] so it
        # broadcasts against the [1, vocab] logits.
        threshold <- top_vals[, k, drop = FALSE]
        logits <- logits$masked_fill(logits < threshold, -Inf)
      }

      probs <- nnf_softmax(logits, dim = -1)
      next_id <- torch_multinomial(probs, 1L)
      idx <- torch_cat(list(idx, next_id), dim = 2L)
    }
  })

  idx$squeeze(1) # remove batch dim: [T + max_new_tokens]
}

# Encode the prompt "ROMEO:" with the character vocabulary and sample 200
# further characters at temperature 1.
seed <- torch_tensor(
  as.integer(stoi[strsplit("ROMEO:", "")[[1]]]),
  dtype = torch_long()
)
ids <- generate(model, seed, max_new_tokens = 200L, temperature = 1.0)
cat(paste(itos[as.character(as.integer(ids))], collapse = ""), "\n")ROMEO:
Ank my cormon. What incy
Whould sand of Rans, ching the generan, you maliciouty as ca standees?
First Citizen:
He's Jus know at sowe, whow'th the know't.
First Citizen:
Nack but spow fort, but it,
The output will not be coherent Shakespeare when trained on a small excerpt. With the full corpus and more training time, the model produces recognisable dialogue structure.
Temperature
Temperature \(\tau\) scales the logits before the softmax. High temperature (\(\tau > 1\)) flattens the distribution, making all tokens more equally likely and output more random. Low temperature (\(\tau < 1\)) sharpens the distribution, making output more repetitive.
\[p_i = \frac{e^{z_i / \tau}}{\sum_j e^{z_j / \tau}}\]
generate(model, seed, max_new_tokens = 100L, temperature = 0.5) # sharp, repetitive
generate(model, seed, max_new_tokens = 100L, temperature = 2.0) # flat, random
Top-k sampling
Top-k sampling restricts sampling to the \(k\) most probable tokens at each step. All other tokens are assigned zero probability.
ids <- generate(model, seed, max_new_tokens = 200L, temperature = 1.0, top_k = 10L)
cat(paste(itos[as.character(as.integer(ids))], collapse = ""), "\n")ROMEO: the leg, toe!
Paicius to the not Carsen: hat unt, greaves en endiculaves strong
Appeal him? aude tire my to likkinke honest wese for his goody,
The not le ave strong are them, noby the nobilly,
And