Protein Structure Prediction: Simplified AlphaFold1
Due Date: April 7th, before midnight
In this assignment, you will implement a simplified version of the AlphaFold (version 1) protein structure prediction method. In particular, you will predict the inter-residue distance matrix and, optionally, the backbone dihedral angles, rather than the full 3D coordinate structure.
For training, validation and testing, we will use the SidechainNet data, which extends ProteinNet with torsion angles and side-chain information. In particular, you will work with the smaller CASP7 set (you may also use the larger CASP12 set). The CASP7 training set (training_30) contains over 10K structures.
SidechainNet data includes the protein sequence, PSSM information, secondary structure information, atomic coordinates of all atoms (backbone and side chain), a boolean mask that indicates whether each atom's coordinates are present, backbone torsion angles, side-chain angles, and so on. On Campuswire, I will share a starter script that extracts this information from the CASP7 training, validation and test sets.
The input to your method will be the training, validation and test files from SidechainNet. Given a protein sequence $S$ of length $L$ from the training set, read its per-residue input features (one-hot encoding, PSSM, etc.) and build an $L \times L \times f$ tensor for $S$, where $f$ is the number of features per pair $(i,j)$ in $S$. To create the $f$ features of a pair, you can concatenate the features (one-hot, PSSM, information content, etc.) of positions $i$ and $j$; you can also add their element-wise product and the absolute value of their difference. At a minimum, try concatenation, then check whether the other alternatives improve the prediction.
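A minimal sketch of the pairwise feature construction, assuming the per-residue features of one protein are already stacked into a PyTorch tensor `x` of shape $(L, d)$ (for example, one-hot columns plus PSSM/information-content columns from the starter script). The function name `pair_features`, its `mode` flag and the feature width 41 in the example are hypothetical.

```python
import torch


def pair_features(x: torch.Tensor, mode: str = "concat") -> torch.Tensor:
    """Build an (L, L, f) pair tensor from (L, d) per-residue features.

    mode="concat": f = 2d, entry (i, j) is [x_i, x_j].
    mode="full":   f = 4d, entry (i, j) is [x_i, x_j, x_i * x_j, |x_i - x_j|].
    """
    L, d = x.shape
    xi = x.unsqueeze(1).expand(L, L, d)  # xi[i, j] = x[i]
    xj = x.unsqueeze(0).expand(L, L, d)  # xj[i, j] = x[j]
    parts = [xi, xj]
    if mode == "full":
        parts += [xi * xj, (xi - xj).abs()]
    elif mode != "concat":
        raise ValueError(f"unknown mode: {mode}")
    return torch.cat(parts, dim=-1)


# Hypothetical example: 5 residues with 41 features each.
x = torch.randn(5, 41)
feats = pair_features(x)               # shape (5, 5, 82)
feats_full = pair_features(x, "full")  # shape (5, 5, 164)
```

PyTorch convolutions expect channels first, so permute the result to $(1, f, L, L)$ with `feats.permute(2, 0, 1).unsqueeze(0)` before passing it to the network.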
Next, implement the residual block framework described in the AlphaFold1 paper. However, you need not train a 220-block network. Instead, you will train several block groups, where each group of 4 blocks cycles through dilations of 1, 2, 4 and 8. Make the number of block groups an input parameter: with 2 block groups, for example, your network has 8 blocks with two cycles of dilations. Each block should consist of the batch-norm layers, ELU activations, projections and dilated convolution described in the paper. These layers and activations are available as built-in PyTorch modules, so you only have to define the architecture and the forward function. For the main architecture of AlphaFold1, refer to Extended Data Fig. 1b in the AlphaFold paper.
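A minimal PyTorch sketch of one reading of Extended Data Fig. 1b: each block applies BatchNorm, ELU, a 1×1 projection that halves the channels, BatchNorm, ELU, a 3×3 dilated convolution, BatchNorm, ELU, and a 1×1 projection back to the full width, then adds the result to the block input. The class names, the 256-channel default, the 1×1 input projection, and the angle head that averages the final activations over $j$ are assumptions for illustration, not the paper's exact configuration.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Bottleneck residual block with a dilated 3x3 convolution."""

    def __init__(self, channels: int = 256, dilation: int = 1):
        super().__init__()
        half = channels // 2
        self.body = nn.Sequential(
            nn.BatchNorm2d(channels), nn.ELU(),
            nn.Conv2d(channels, half, kernel_size=1),        # 1x1 projection down
            nn.BatchNorm2d(half), nn.ELU(),
            nn.Conv2d(half, half, kernel_size=3,
                      dilation=dilation, padding=dilation),  # padding keeps H, W fixed
            nn.BatchNorm2d(half), nn.ELU(),
            nn.Conv2d(half, channels, kernel_size=1),        # 1x1 projection up
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)  # skip connection


class DistogramNet(nn.Module):
    """num_groups groups of 4 blocks, each group cycling dilations 1, 2, 4, 8."""

    def __init__(self, in_features: int, num_groups: int = 2, channels: int = 256,
                 n_dist_bins: int = 64, n_angle_bins: int = 1296):
        super().__init__()
        self.input_proj = nn.Conv2d(in_features, channels, kernel_size=1)
        self.blocks = nn.Sequential(*[
            ResidualBlock(channels, d) for _ in range(num_groups) for d in (1, 2, 4, 8)
        ])
        self.dist_head = nn.Conv2d(channels, n_dist_bins, kernel_size=1)
        self.angle_head = nn.Linear(channels, n_angle_bins)

    def forward(self, x: torch.Tensor):
        # x: (B, f, H, W) pair features for one crop, channels first.
        h = self.blocks(self.input_proj(x))   # (B, C, H, W)
        dist_logits = self.dist_head(h)       # (B, n_dist_bins, H, W)
        rows = h.mean(dim=3).transpose(1, 2)  # (B, H, C): average over j (assumption)
        angle_logits = self.angle_head(rows)  # (B, H, n_angle_bins)
        return dist_logits, angle_logits


# Small width so the example runs quickly.
model = DistogramNet(in_features=82, num_groups=2, channels=32)
dist_logits, angle_logits = model(torch.randn(2, 82, 64, 64))
```

With `num_groups=2` this builds the 8-block, two-cycle network described above. Both heads return raw logits, so pass them directly to `torch.nn.functional.cross_entropy` rather than applying a softmax first.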
You will train on each $64 \times 64$ crop separately. To create crops, zero-pad the input tensor as appropriate, pick a randomly chosen starting position $(i,j)$ with $j>i$, and from there generate all tiles with a stride of 64 for non-overlapping crops (or 16 or 32 if you want the crops to overlap). Restrict $i$ and $j$ to the first 64 entries (after zero-padding) along each dimension, so that the crops cover the entire protein. Different epochs should start at different $(i,j)$ positions for the same protein. After predicting the distances for a crop, compare them with the true distances for that crop only.
First, discretize the distances between $2$ and $22\,A^\circ$ into 64 equal bins of width $20/64 = 0.3125\,A^\circ$, with the last bin covering all distances greater than $21.6875\,A^\circ$. This gives a total of 64 distance symbols. You can then use the cross-entropy loss between the predicted probabilities and the true distance symbols. The second head directly predicts the $(\phi,\psi)$ angles at each position, discretized into $1296 = 36 \times 36$ joint bins (i.e., $10^\circ$ bins for each of $\phi$ and $\psi$).
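A minimal sketch of the target discretization and cropping, under these assumptions: true distances come as an $L \times L$ tensor in $A^\circ$ (computed from the coordinates, with NaN where coordinates are missing), distances below $2\,A^\circ$ fall into bin 0, torsion angles are given in radians in $[-\pi, \pi]$, and the zero padding places 64 empty positions before the protein. All function names and the `IGNORE` label are hypothetical; padded or missing pairs get `IGNORE` so that the loss skips them.

```python
import math
import random

import torch

CROP, N_BINS = 64, 64
D_MIN, D_MAX = 2.0, 22.0
BIN_W = (D_MAX - D_MIN) / N_BINS  # 0.3125 Å
IGNORE = -100                     # default ignore_index of torch's cross-entropy


def distance_to_bins(dist: torch.Tensor) -> torch.Tensor:
    """[2, 2.3125) -> 0, ..., [21.6875, inf) -> 63; below 2 -> 0; NaN -> IGNORE."""
    missing = torch.isnan(dist)
    safe = torch.where(missing, torch.zeros_like(dist), dist)
    idx = torch.floor((safe - D_MIN) / BIN_W).clamp(0, N_BINS - 1).long()
    idx[missing] = IGNORE
    return idx


def angles_to_bins(phi: torch.Tensor, psi: torch.Tensor, n: int = 36) -> torch.Tensor:
    """Joint (phi, psi) bin in 0..n*n-1, with n bins of 360/n degrees per angle."""
    def one(a):
        return torch.floor((a + math.pi) / (2 * math.pi / n)).long().clamp(0, n - 1)
    return one(phi) * n + one(psi)


def random_offsets(crop: int = CROP, rng=random):
    """Random start (i, j) with j > i, both within the first `crop` padded entries."""
    i = rng.randrange(0, crop - 1)
    j = rng.randrange(i + 1, crop)
    return i, j


def tile_starts(L: int, offset: int, crop: int = CROP, stride: int = CROP) -> list:
    """Tile starts in padded coordinates; the protein occupies [crop, crop + L)."""
    starts, s = [], offset
    while s < crop + L:  # tile begins before the protein ends
        if s > 0:        # tile [s, s + crop) reaches index crop (skip all-padding tile)
            starts.append(s)
        s += stride
    return starts


def make_crops(feats, bins, off_i, off_j, crop=CROP, stride=CROP):
    """Yield (x, y) crops; feats is (f, L, L) float, bins is (L, L) long."""
    f, L, _ = feats.shape
    si = tile_starts(L, off_i, crop, stride)
    sj = tile_starts(L, off_j, crop, stride)
    total = max(si[-1], sj[-1]) + crop      # padded side length
    x = feats.new_zeros((f, total, total))  # zero padding
    y = torch.full((total, total), IGNORE, dtype=torch.long)
    x[:, crop:crop + L, crop:crop + L] = feats
    y[crop:crop + L, crop:crop + L] = bins
    for a in si:
        for b in sj:
            yield x[:, a:a + crop, b:b + crop], y[a:a + crop, b:b + crop]


# Hypothetical example with random stand-ins for one protein of length 100.
coords = torch.randn(100, 3) * 10
bins = distance_to_bins(torch.cdist(coords, coords))
crops = list(make_crops(torch.randn(82, 100, 100), bins, *random_offsets()))
```

During training, `torch.nn.functional.cross_entropy(logits, y, ignore_index=IGNORE)` accepts `logits` of shape $(B, 64, 64, 64)$ and targets `y` of shape $(B, 64, 64)$. Calling `random_offsets()` once per protein per epoch gives different crop positions across epochs.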
You can monitor the validation loss to tune hyperparameters and to decide when to stop training early.
For testing, report the loss and also the accuracy of contact prediction. A pair $(i,j)$ is in contact if its true distance is below $8A^\circ$. So, for each test sequence, sum the predicted probabilities of the bins corresponding to the "symbols" in the range $2-8A^\circ$; if that sum is over 0.5, predict that the pair is in contact. Then report the accuracy for short-, medium- and long-range contacts, defined by sequence separation: short is $6 \le |i-j| \le 11$, medium is $12 \le |i-j| \le 23$, and long is $|i-j| \ge 24$. In each category, report the accuracy for the top $L$, $L/2$ and $L/5$ predictions, where $L$ is the sequence length. Finally, average the accuracy in each range over all of the test proteins and report that number.
Note: accuracy (or precision) is defined as $TP/(TP+FP)$, where $TP$ is the number of true positives and $FP$ is the number of false positives. A true positive is a pair that is both predicted and truly in contact; a false positive is a pair that you predict to be in contact but that is not truly in contact. Also, if there are fewer than $L$ (or $L/2$ or $L/5$) true contacts, use only the true number when evaluating the metrics; e.g., if $L=100$ but there are only 90 true contacts, report results for the top 90 (not the top 100), and so on.
So for any $L/k$ value, first find your top $L/k$ predictions and count how many of them are correct, then divide that count by $L/k$. Since there may be fewer than $L/k$ true contacts, especially for long-range contacts, do the following:
Accuracy for $L/k$:
1. denom = min($L/k$, number of true contacts in the group: short, medium or long)
2. Sort your predictions in decreasing order of contact probability and select the top denom predictions.
3. num_correct = how many of these predictions are true contacts.
4. accuracy @ $L/k$ = num_correct / denom
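A minimal sketch of these contact metrics for one protein, assuming `dist_probs` is an $L \times L \times 64$ tensor of softmax probabilities over the distance bins for the full protein (e.g., assembled from crop predictions) and `true_dist` holds the true distances in $A^\circ$, with NaN where coordinates are missing. Because $8A^\circ$ falls inside bin 19, which covers $[7.9375, 8.25)$, the sketch sums bins 0 to 18 only. That cutoff choice, counting each pair once ($i<j$), skipping pairs with missing coordinates, rounding $L/k$ down, and the helper names are all assumptions.

```python
import torch

D_MIN, BIN_W, CUTOFF = 2.0, 0.3125, 8.0
RANGES = {"short": (6, 11), "medium": (12, 23), "long": (24, None)}


def contact_probs(dist_probs: torch.Tensor) -> torch.Tensor:
    """(L, L, 64) bin probabilities -> (L, L) probability that distance < 8 Å."""
    n_close = int((CUTOFF - D_MIN) // BIN_W)  # 19 bins: 0..18, upper edge 7.9375 Å
    return dist_probs[..., :n_close].sum(dim=-1)


def accuracy_at(p_contact, true_dist, k, seq_range):
    """Accuracy @ L/k for one protein and one range; None if it has no true contacts."""
    L = true_dist.shape[0]
    lo, hi = RANGES[seq_range]
    i, j = torch.triu_indices(L, L, offset=lo)  # all pairs with j - i >= lo
    keep = ~torch.isnan(true_dist[i, j])        # drop pairs with missing coordinates
    if hi is not None:
        keep &= (j - i) <= hi
    i, j = i[keep], j[keep]
    is_contact = true_dist[i, j] < CUTOFF
    denom = min(L // k, int(is_contact.sum()))  # step 1
    if denom == 0:
        return None
    top = torch.argsort(p_contact[i, j], descending=True)[:denom]  # step 2
    num_correct = int(is_contact[top].sum())                       # step 3
    return num_correct / denom                                     # step 4


def average_accuracy(results):
    """Average over test proteins, skipping proteins with no true contacts."""
    vals = [r for r in results if r is not None]
    return sum(vals) / len(vals) if vals else float("nan")


# Hypothetical example: L = 30 with two long-range true contacts, (0, 24) and (1, 25).
L = 30
true_dist = torch.full((L, L), 20.0)
true_dist[0, 24] = true_dist[1, 25] = 5.0
p = torch.full((L, L), 0.1)
p[0, 24], p[2, 26] = 0.9, 0.8  # one correct and one wrong prediction at the top
acc = accuracy_at(p, true_dist, k=1, seq_range="long")  # denom = min(30, 2) = 2
```

In the example only two long-range true contacts exist, so the top 2 (not the top 30) predictions are scored, and one of them is correct.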
Submit your notebook (or Python script) via Submitty, along with an output file (txt/pdf) that summarizes the results of your method in terms of training and test accuracy values. If you submit a notebook, the results can be part of the notebook. You should report the test loss and the contact map accuracy for CASP7 (and, if interested, for CASP12 too).
You may want to use multiple GPUs on the DCS cluster to speed up your training.
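One simple option, sketched under the assumption that all GPUs sit on a single node: PyTorch's `nn.DataParallel` splits each batch of crops across the visible GPUs. BatchNorm statistics are then computed separately on each GPU; `torch.nn.parallel.DistributedDataParallel` (optionally with `nn.SyncBatchNorm`) is the faster, recommended alternative but needs more setup. The helper name `to_devices` is hypothetical.

```python
import torch
import torch.nn as nn


def to_devices(model: nn.Module) -> nn.Module:
    """Move the model to GPU(s) if available; wrap in DataParallel when >1 GPU is visible."""
    if torch.cuda.is_available():
        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)  # splits dim 0 (the batch) across GPUs
    return model


model = to_devices(nn.Conv2d(82, 64, kernel_size=1))  # stand-in for the real network
device = next(model.parameters()).device
```

Move each batch with `x.to(device)` before the forward pass; on a machine without GPUs the sketch leaves the model on the CPU.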