More notes on GPT-2 fine-tuning
I'm currently fine-tuning GPT-2 on the full 43 MB catalog of Accidental Tech Podcast transcripts. It's running on a GCP TPUv2-8 (and previously on a v3-8).
Some learnings and notes:
I had a hell of a time getting anything to train on a TPU in GCP. Google's recommended way of provisioning TPUs, the ctpu command, never set things up correctly for me; I suspect it wasn't loading the specified TensorFlow version (1.15) onto the TPU. Thanks to a tip from Shawn, I switched to provisioning with gcloud commands. I've also used the web console.
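For reference, here is a minimal Python sketch of gcloud-based provisioning. It assumes the TPU Node-era `gcloud compute tpus create` command with its `--zone`, `--accelerator-type`, and `--version` flags (check `gcloud compute tpus create --help` for your SDK version). The point is that `--version` pins the TensorFlow runtime explicitly. The TPU name and zone are hypothetical placeholders, and this is not necessarily Shawn's exact setup.

```python
import shlex
import subprocess


def tpu_create_command(name: str, zone: str, accelerator_type: str = "v2-8",
                       tf_version: str = "1.15") -> list[str]:
    """Build a `gcloud compute tpus create` invocation as an argument list."""
    return [
        "gcloud", "compute", "tpus", "create", name,
        f"--zone={zone}",
        f"--accelerator-type={accelerator_type}",
        # Pin the TPU's TensorFlow runtime so it matches the VM's TF 1.15.
        f"--version={tf_version}",
    ]


def create_tpu(name: str, zone: str, **kwargs) -> None:
    """Print and run the command; raises CalledProcessError if gcloud fails."""
    cmd = tpu_create_command(name, zone, **kwargs)
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Hypothetical TPU name and zone; use a zone that actually offers v2-8s.
    create_tpu("ats-tpu", "us-central1-f")
```

Building the command as a list (rather than one shell string) avoids quoting surprises and makes it easy to swap `accelerator_type="v3-8"` when one is available.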
I started using larger batch sizes during training to take advantage of the TPU's ample memory. On a TPUv3-8 I used a batch size of 8, while on a v2-8 I had to drop to 4, which fits with v2 cores having half the high-bandwidth memory of v3 cores (8 GiB vs. 16 GiB). I'm interested to see how batch sizes > 1 affect the sample output. Really, I'm itching to see metrics on TPU utilization; I have no idea whether I'm using it to its full potential.
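Some rough arithmetic on what the batch size buys per step. This sketch assumes every training example fills GPT-2's full 1,024-token context, and that 43 MB of English transcript text comes to roughly 10.75 million BPE tokens (a hypothetical ~4 bytes per token; the real count depends on the encoder). Both are assumptions, not measurements from this run.

```python
GPT2_CONTEXT = 1024  # GPT-2's maximum context length, in BPE tokens


def tokens_per_step(batch_size: int, context: int = GPT2_CONTEXT) -> int:
    """Tokens processed per optimizer step if every example fills the context."""
    return batch_size * context


def steps_per_epoch(dataset_tokens: int, batch_size: int,
                    context: int = GPT2_CONTEXT) -> int:
    """Steps needed to process one dataset's worth of tokens (ceiling division)."""
    per_step = tokens_per_step(batch_size, context)
    return -(-dataset_tokens // per_step)


# Hypothetical estimate: ~4 bytes of English text per BPE token.
ASSUMED_BYTES_PER_TOKEN = 4
dataset_tokens = 43_000_000 // ASSUMED_BYTES_PER_TOKEN  # 10,750,000

for name, batch in [("v3-8", 8), ("v2-8", 4)]:
    print(f"{name}: batch {batch}, {tokens_per_step(batch):,} tokens/step, "
          f"~{steps_per_epoch(dataset_tokens, batch):,} steps per pass")
```

Under those assumptions, one pass's worth of tokens is about 1,313 steps at batch size 8 versus about 2,625 at batch size 4, so each v2-8 step sees half as much text as a v3-8 step.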
I couldn't get a TPU on Colab all weekend. Even on GCP last night, I couldn't get a TPUv3-8 in my region and had to use a v2-8.
I used GCP's Cloud Shell for my first fine-tuning run but didn't realize that it disconnects after an hour of inactivity. I now run the fine-tuning inside tmux on the GCP VM so that SSH disconnects can't mess everything up.
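A minimal sketch of the tmux pattern, written in Python to match the other examples here. `tmux new-session -d -s NAME COMMAND` and `tmux attach-session -t NAME` are standard tmux commands; `train.py` and its `--batch_size` flag are hypothetical stand-ins for whatever fine-tuning script you actually run.

```python
import shlex
import subprocess


def tmux_launch_command(session: str, train_cmd: list[str]) -> list[str]:
    """Build a command that starts train_cmd in a detached tmux session."""
    # -d: don't attach, so the session (and the job) outlives this SSH login.
    # tmux hands the final argument to a shell, so pass it as one quoted string.
    return ["tmux", "new-session", "-d", "-s", session, shlex.join(train_cmd)]


def tmux_attach_command(session: str) -> list[str]:
    """Build a command that reattaches to the session after reconnecting."""
    return ["tmux", "attach-session", "-t", session]


if __name__ == "__main__":
    # Hypothetical training invocation; substitute your own script and flags.
    train = ["python3", "train.py", "--batch_size", "4"]
    subprocess.run(tmux_launch_command("finetune", train), check=True)
    print("Reattach later with:", shlex.join(tmux_attach_command("finetune")))
```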
My GPT-2 checkpoint lives on the VM's disk and is uploaded to the TPU each time I train. Something had been eating away at the VM's memory over the weekend, and last night my training process kept getting killed for running out of memory (OOM) while loading the checkpoint from disk. Restarting the VM fixed things, but the real lesson is that I should keep the checkpoint in Google Cloud Storage and just pass its gs:// path to the TPU.
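A sketch of the Cloud Storage approach, assuming a standard TensorFlow 1.x checkpoint directory (a `checkpoint` state file plus `.index`, `.meta`, and `.data-*` files sharing one prefix) and the `gsutil` CLI. The idea is to upload the newest checkpoint once and then hand the training script a `gs://` prefix, which TensorFlow 1.x can read directly through its built-in GCS filesystem support (for example in `tf.train.Saver.restore`). The local directory and bucket names are hypothetical.

```python
import glob
import os
import re


def latest_prefix(state_text: str, ckpt_dir: str) -> str:
    """Find the newest checkpoint prefix named in a TF `checkpoint` state file."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', state_text, re.MULTILINE)
    if match is None:
        raise ValueError("no model_checkpoint_path entry in checkpoint state")
    # The state file may hold a relative or absolute path; keep only the stem.
    return os.path.join(ckpt_dir, os.path.basename(match.group(1)))


def checkpoint_files(prefix: str) -> list[str]:
    """List the files sharing one checkpoint prefix (.index, .meta, .data-*)."""
    # The literal "." keeps model-1000 from also matching model-10000.
    return sorted(glob.glob(glob.escape(prefix) + ".*"))


def gsutil_upload_command(files: list[str], gcs_dir: str) -> list[str]:
    """Build a parallel (-m) gsutil copy of the checkpoint files into gcs_dir."""
    if not gcs_dir.startswith("gs://"):
        raise ValueError(f"expected a gs:// URL, got {gcs_dir!r}")
    return ["gsutil", "-m", "cp", *files, gcs_dir.rstrip("/") + "/"]


def remote_prefix(prefix: str, gcs_dir: str) -> str:
    """The gs:// prefix to give the training script once the upload is done."""
    return gcs_dir.rstrip("/") + "/" + os.path.basename(prefix)


if __name__ == "__main__":
    import subprocess

    ckpt_dir = "checkpoint/run1"             # hypothetical local run directory
    bucket_dir = "gs://my-gpt2-bucket/run1"  # hypothetical bucket
    with open(os.path.join(ckpt_dir, "checkpoint")) as f:
        prefix = latest_prefix(f.read(), ckpt_dir)
    subprocess.run(gsutil_upload_command(checkpoint_files(prefix), bucket_dir), check=True)
    print("Restore from:", remote_prefix(prefix, bucket_dir))
```

Uploading only the newest prefix keeps the bucket small; the `checkpoint` state file itself is left behind because its stored paths point at the VM's local layout.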
I was able to fine-tune on the 43 MB dataset for much of the weekend and Monday. However, I noticed last night that over a full workday the loss had dropped by only 0.1 (from 2.9 to 2.8). I hadn't seen this in my previous runs and wonder whether it's due to the larger batch size. I've reduced the learning rate from 0.0001 to 0.00005 to see if that helps. (Update from this evening: the loss has dropped to 2.4 over the course of the day.)
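Two small sketches for reading those loss numbers, both assumptions rather than what my training script actually does. First, if the loss is mean per-token cross-entropy in nats (as in the usual GPT-2 training code), `exp(loss)` gives perplexity, so 2.9 to 2.8 is roughly 18.2 to 16.4, and 2.4 is roughly 11.0. Second, the manual learning-rate cut could be automated with a swapped-in "reduce on plateau" rule: keep a bias-corrected exponential moving average of the loss and halve the rate whenever it fails to drop by a threshold over a window of steps. The class name, window, and threshold are hypothetical.

```python
import math


def perplexity(loss_nats: float) -> float:
    """Perplexity for a mean per-token cross-entropy measured in nats."""
    return math.exp(loss_nats)


class PlateauHalver:
    """Halve the learning rate when the smoothed loss stops improving."""

    def __init__(self, lr: float, window: int, min_drop: float, decay: float = 0.99):
        self.lr = lr
        self.window = window      # steps between plateau checks
        self.min_drop = min_drop  # required drop in smoothed loss per window
        self.decay = decay        # EMA decay; closer to 1.0 means smoother
        self._num = 0.0
        self._den = 0.0
        self._steps = 0
        self._last_checked = None

    @property
    def smoothed_loss(self) -> float:
        # Dividing by the decayed step count corrects the EMA's startup bias.
        return self._num / self._den

    def update(self, loss: float) -> float:
        """Record one step's loss and return the learning rate to use next."""
        self._num = self._num * self.decay + loss
        self._den = self._den * self.decay + 1.0
        self._steps += 1
        if self._steps % self.window == 0:
            current = self.smoothed_loss
            if self._last_checked is not None and self._last_checked - current < self.min_drop:
                self.lr /= 2
            self._last_checked = current
        return self.lr


for loss in (2.9, 2.8, 2.4):
    print(f"loss {loss}: perplexity {perplexity(loss):.1f}")
```

For example, `PlateauHalver(lr=1e-4, window=1000, min_drop=0.02)` would make the same 0.0001 to 0.00005 cut by itself if the smoothed loss flattened out; the window and threshold are tuning choices, and too small a window will react to step-to-step noise.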