Training m6AnetΒΆ

m6Anet expects a training config file and a model config file, both on TOML format. We have provided examples of the training config and model config file in:

  • m6anet/m6anet/model/configs/model_configs/prod_pooling.toml
  • m6anet/m6anet/model/configs/training_configs/oversampled.toml

Below is the content of oversampled.toml

[loss_function]
loss_function_type = "binary_cross_entropy_loss"

[dataset]
root_dir = "/path/to/m6anet-dataprep/output"
min_reads = 20
norm_path = "/path/to/m6anet/m6anet/model/norm_factors/norm_dict.joblib"
num_neighboring_features = 1

[dataloader]
    [dataloader.train]
    batch_size = 256
    sampler = "ImbalanceOverSampler"

    [dataloader.val]
    batch_size = 256
    shuffle = false

    [dataloader.test]
    batch_size = 256
    shuffle = false

User can modify some basic training information such as the batch_size, the number of neighboring features, as well as the minimum number of reads per site to train m6Anet. We have also calculated the normalization factors required under norm_path variable. In principle, one can even change the loss_function_type by choosing one from m6anet/m6anet/utils/loss_functions.py or defining a new one. Sampler can be set to ImbalanceOverSampler (in which the model will perform oversampling to tackle the data imbalance with m6Anet modification) or any other sampler from m6anet/m6anet/utils/data_utils.py

The training script will look for data.readcount.labelled file and data.index file under the root_dir directory. While data.index can be obtained by running m6anet-dataprep over nanopolish eventalign.txt file, data.readcount.labelled must be supplied by the user by adding extra columns to the data.readcount file produced by m6anet-dataprep. Additionally, data.readcount.labelled must be of the following format:

transcript_id   transcript_position n_reads modification_status set_type
ENST00000361055 549                 11      0                   Train
ENST00000361055 554                 12      0                   Train
ENST00000475035 133                  3      0                   Train
ENST00000222329 309                 11      0                   Val
ENST00000222329 2496                15      0                   Val
ENST00000222329 2631                23      0                   Val
ENST00000523944 72                   1      0                   Test
ENST00000523944 2196                14      0                   Test

Here modification status tells the model which positions are modified and which positions are not modified. The column set_type informs the training script which part of the data we should train on and which part of the data should be used for validation and testing purpose. Lastly, n_reads corresponds to the number of reads that comes from the corresponding transcript positions and any sites with n_reads less than the min_reads specified in he training config file will not be used for training validation, or testing. We have also provided an example of data.readcount.labelled in m6anet/demo/ folder.

Below is the content of prod_pooling.toml:

model = "prod_sigmoid_pooling"

[[block]]
block_type = "DeaggregateNanopolish"
num_neighboring_features = 1

[[block]]
block_type = "KmerMultipleEmbedding"
input_channel = 66
output_channel = 2
num_neighboring_features = 1

[[block]]
block_type = "ConcatenateFeatures"

[[block]]
block_type = "Linear"
input_channel = 15
output_channel = 150
activation = "relu"
batch_norm = true

[[block]]
block_type = "Linear"
input_channel = 150
output_channel = 32
activation = "relu"
batch_norm = false

[[block]]
block_type = "SigmoidProdPooling"
input_channel = 32
n_reads_per_site = 20

The training script will build the model block by block. For additional information on the block type, please check the source code under m6anet/m6anet/model/model_blocks

In order to train m6Anet, please change the root_dir variable inside prod_pooling.toml to m6anet/demo/. Afterwards, run m6anet-dataprep:

m6anet-dataprep --eventalign m6anet/demo/eventalign.txt \
               --out_dir m6anet/demo/ --n_processes 4

This will produce data.index file and data.json file that will be used for the script to access the preprocessed data Next, to train m6Anet using the demo data, run:

m6anet-train --model_config m6anet/model/configs/model_configs/prod_pooling.toml --train_config ../m6anet/model/configs/training_configs/oversampled.toml --save_dir /path/to/save_dir --device cpu --lr 0.0001 --seed 25 --epochs 30 --num_workers 4 --save_per_epoch 1 --num_iterations 5

The model will be trained on cpu for 30 epochs and we will save the model states every 1 epoch. One can replace the device argument with cuda to train with GPU. For complete description of the command line arguments, please see Command line arguments page