why models (such as llms) shouldn’t train on their own generated data
This is a very simple, intuitive example of why having a model train on its own data can be detrimental. How it extends to LLMs involves additional complexities, and there are good technical papers on the subject worth reading. This is meant as a quick explanation for the layman.
simplification
To explore what might happen if we train an LLM on its own data, we need to make some simplifications. An LLM outputs long sequences of tokens, so two natural points of simplification arise: length and width. On the length side, rather than a long context / large n-gram model, we use a unigram model. On the width side, rather than thousands of distinct tokens, we use just 2 tokens, 0 and 1. We are left with a simple unigram model that outputs strings of text consisting only of 0s and 1s.
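Concretely, this simplified model is just a biased coin: a single parameter p gives the probability of emitting a 1, and "fitting" the model to a string means counting the fraction of 1s. A minimal sketch (the function names here are only for illustration):

import numpy as np

def generate(p, length):
    # sample a "text" of 0s and 1s, where each token is 1 with probability p
    return np.random.choice([0, 1], size=length, p=[1 - p, p])

def fit(text):
    # the maximum-likelihood estimate of p is simply the fraction of 1s
    return np.mean(text)

# e.g. fit([1, 0, 1, 1, 0]) == 0.6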
training on its own generations
The process of training on its own data is then very simple:
- start with a particular value of the unigram parameter (e.g. p = 0.5, the probability of outputting a 1)
- generate a text of 0s and 1s of some length, say 10
- fit the unigram probability on the generated text
- repeat the three steps above many times, say 60, and track the parameter over time
In code it looks something like this:
import numpy as np

def experiment_rollout(coin_prob=0.5, k_toss=10, n_rounds=60):
    results = []
    for i in range(n_rounds):
        # record the current value of the parameter
        results.append(coin_prob)
        # use coin_prob to generate a text of k_toss tokens (0s and 1s)
        tosses = np.random.choice([0, 1], size=k_toss, p=[1 - coin_prob, coin_prob])
        # refit the unigram: the new probability is the fraction of 1s in the generated text
        coin_prob = np.mean(tosses)
    return results
Simulating the above rollout 100 times, we get the following plot: on the x-axis is the number of times the model has been trained on its own data, and on the y-axis is the value of the model's parameter p.
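One way to produce such a plot (a sketch assuming matplotlib; the original figure may have been generated differently):

import matplotlib.pyplot as plt

# overlay the parameter trajectories of 100 independent rollouts
for _ in range(100):
    trajectory = experiment_rollout(coin_prob=0.5, k_toss=10, n_rounds=60)
    plt.plot(trajectory, alpha=0.2, color="tab:blue")

plt.xlabel("rounds of training on own data")
plt.ylabel("model parameter p")
plt.show()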
It is easy to see that every model degrades into outputting either all 0s or all 1s toward the end, as p becomes polarized to either 1.0 or 0.0 over time through training on its own data.
It is also easy to see that if we generate shorter texts (say length 1 instead of length 10), the degradation can happen very quickly due to chance. The reverse is also true: with longer texts (say length 1000), the degradation is slower, since 1000 tosses approximate the original value of p very well.
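To check this directly, one could compare how long it takes p to get absorbed at 0 or 1 for different text lengths. A rough sketch reusing experiment_rollout from above; rounds_until_collapse is a hypothetical helper, and runs that never collapse within n_rounds are simply capped at n_rounds:

import numpy as np

def rounds_until_collapse(k_toss, n_rounds=1000, n_trials=50):
    # average number of self-training rounds before p hits 0 or 1 (capped at n_rounds)
    collapse_times = []
    for _ in range(n_trials):
        traj = experiment_rollout(coin_prob=0.5, k_toss=k_toss, n_rounds=n_rounds)
        absorbed = [i for i, p in enumerate(traj) if p in (0.0, 1.0)]
        collapse_times.append(absorbed[0] if absorbed else n_rounds)
    return np.mean(collapse_times)

# shorter texts collapse almost immediately; longer texts take many more rounds
for k in (1, 10, 1000):
    print(k, rounds_until_collapse(k))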
Thus, an LLM with many parameters that generates very long sequences may degrade slowly, but surely: the noise stacks up over time in a positive feedback loop before ossifying the entire model.
additional reading
For additional reading, look up the Pólya urn model.
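As a quick flavor of that model: in a Pólya urn you repeatedly draw a ball and return it along with an extra ball of the same color, so early random draws get reinforced and the color proportion settles onto a path-dependent limit. A minimal sketch, just for illustration:

import numpy as np

def polya_urn(n_draws=1000):
    # start with one ball of each color; each draw adds another ball of the drawn color
    counts = [1, 1]
    for _ in range(n_draws):
        total = counts[0] + counts[1]
        color = np.random.choice([0, 1], p=[counts[0] / total, counts[1] / total])
        counts[color] += 1
    # final proportion of color 1; different runs settle on different values purely by chance
    return counts[1] / (counts[0] + counts[1])

print([round(polya_urn(), 2) for _ in range(5)])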
thanks for reading
— evan