Training a model from text
This tutorial walks you through training and using a neural network model to classify newsgroup posts into twenty different categories. It makes use of a classical dataset in machine learning, often used for educational purposes.
Note: this tutorial deals with bag-of-words (BOW) models; a separate page is dedicated to the more novel character-based models.
In summary, the dataset comes as a repository containing 20 sub-repositories of text files, each file being a newsgroup post.
Getting the dataset
Let us create a dedicated repository:
mkdir models
mkdir models/n20
The data can be obtained from http://www.deepdetect.com/dd/examples/all/n20/news20.tar.bz2:
cd models/n20
wget http://www.deepdetect.com/dd/examples/all/n20/news20.tar.bz2
tar xvjf news20.tar.bz2
You can take a look at the raw data:
less news20/sci_crypt/000000616.eml
There are around 20000 files in the dataset.
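As a quick sanity check (assuming you are still in models/n20), you can count the newsgroups and posts from the shell:

ls news20 | wc -l             # 20 newsgroup sub-repositories
find news20 -type f | wc -l   # around 20000 posts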
Creating the machine learning service
The first step with DeepDetect is to start the server:
./dede
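By default the server listens on port 8080. You can check that it is up and responding with the info resource of the API:

curl -X GET "http://localhost:8080/info"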
and create a machine learning service that uses a multi-layer perceptron with two hidden layers of 200 neurons each and relu activations:
curl -X PUT "http://localhost:8080/services/n20" -d '{
  "mllib":"caffe",
  "description":"newsgroup classification service",
  "type":"supervised",
  "parameters":{
    "input":{
      "connector":"txt"
    },
    "mllib":{
      "template":"mlp",
      "nclasses":20,
      "layers":[200,200],
      "activation":"relu"
    }
  },
  "model":{
    "templates":"../templates/caffe/",
    "repository":"models/n20"
  }
}'
yields:
{
  "status":{
    "code":201,
    "msg":"Created"
  }
}
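You can review the service configuration at any time with a GET call on the same resource:

curl -X GET "http://localhost:8080/services/n20"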
Training and testing the service
Let us now train a statistical model in the form of the neural network defined above. Below is the full API call for launching an asynchronous training job on the GPU (with automatic fallback to the CPU if no GPU is present). Take a look at it, and before proceeding with the call, let us review it in detail below. We train on 80% of the dataset and test on the remaining 20%.
curl -X POST "http://localhost:8080/train" -d '{
  "service":"n20",
  "async":true,
  "parameters":{
    "mllib":{
      "gpu":true,
      "solver":{
        "iterations":2000,
        "test_interval":200,
        "base_lr":0.05
      },
      "net":{
        "batch_size":300
      }
    },
    "input":{
      "shuffle":true,
      "test_split":0.2,
      "min_count":10,
      "min_word_length":5,
      "count":false
    },
    "output":{
      "measure":["mcll","f1"]
    }
  },
  "data":["models/n20/news20"]
}'
First and foremost, we are using our newly created service to train a model. This means that the service will be busy for some time, and we cannot use it for anything other than reviewing the training call's status and progress. Other services, if any, would of course remain available. In more detail:
- async allows starting a non-blocking (i.e. asynchronous) call
- gpu tells the server we would like to use the GPU. Importantly, note that in the absence of a GPU, the server automatically falls back on the CPU, without warning
- iterations is the number of training iterations after which training terminates automatically. Until termination it is possible to get the status and progress of the call, as we demonstrate below
- min_count rejects words that do not appear often enough
- min_word_length rejects words whose length is below the specified limit
- count determines whether to build a counter for each word or to use 0 and 1 only
- measure lists the assessment metrics of the model being built, mcll for multi-class log loss and f1 for F1-score
- data holds the dataset repository
For more details on the training phase options and parameters, see the API.
Let us now run the call above; the immediate answer is:
{
  "status":{
    "code":201,
    "msg":"Created"
  },
  "head":{
    "method":"/train",
    "job":1,
    "status":"running"
  }
}
indicating that the call was successful and the training is now running.
You can get the status of the call anytime with another call:
curl -X GET "http://localhost:8080/train?service=n20&job=1"
yields:
{
  "status":{
    "msg":"OK",
    "code":200
  },
  "body":{
    "parameters":{
      "mllib":{
        "batch_size":359
      }
    },
    "measure":{
      "f1":0.8919178423728972,
      "train_loss":0.0016851313412189484,
      "mcll":0.5737156999301365,
      "recall":0.8926410552973584,
      "iteration":1999.0,
      "precision":0.8911958003860988,
      "accp":0.8936339522546419
    }
  },
  "head":{
    "status":"finished",
    "job":1,
    "method":"/train",
    "time":541.0
  }
}
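Since training runs asynchronously, you may want to poll the status until the job finishes. A minimal shell sketch (the 10-second interval is arbitrary):

while true; do
  status=$(curl -s "http://localhost:8080/train?service=n20&job=1")
  echo "$status"
  # stop polling once the job reports the finished status
  echo "$status" | grep -q '"finished"' && break
  sleep 10
done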
Using the service
You can get predictions on text files and raw text very easily:
curl -X POST 'http://localhost:8080/predict' -d '{
  "service":"n20",
  "parameters":{
    "mllib":{
      "gpu":true
    }
  },
  "data":["my computer runs linux"]
}'
which yields:
{
  "status":{
    "code":200,
    "msg":"OK"
  },
  "head":{
    "method":"/predict",
    "time":226.0,
    "service":"n20"
  },
  "body":{
    "predictions":{
      "uri":"0",
      "classes":{
        "last":true,
        "prob":0.3948741555213928,
        "cat":"comp_graphics"
      }
    }
  }
}
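As mentioned above, the data field also accepts paths to text files; for instance, using one of the raw posts from the dataset:

curl -X POST 'http://localhost:8080/predict' -d '{
  "service":"n20",
  "parameters":{
    "mllib":{
      "gpu":true
    }
  },
  "data":["models/n20/news20/sci_crypt/000000616.eml"]
}'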
Restarting and using the service
Importantly, when re-creating the service you need to use a call that does not override your existing model architecture:
curl -X PUT "http://localhost:8080/services/n20" -d '{
  "mllib":"caffe",
  "description":"newsgroup classification service",
  "type":"supervised",
  "parameters":{
    "input":{
      "connector":"txt"
    },
    "mllib":{
      "nclasses":20
    }
  },
  "model":{
    "repository":"models/n20"
  }
}'
The call above no longer specifies the template parameter, since your model architecture has already been defined and trained.
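If you need to remove a service from a running server, for instance before re-creating it, the API also provides a deletion call:

curl -X DELETE "http://localhost:8080/services/n20"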
When the model has been overridden by mistake, the typical error you'd get from the server is:
{
  "status":{
    "code":500,
    "msg":"InternalError",
    "dd_code":1007,
    "dd_msg":"./include/caffe/llogging.h:66 / Fatal Caffe error"
  }
}
This means Caffe was not able to load the trained model as it doesn’t fit the model architecture anymore.
Tip: if the above error occurs and you do not wish to go through a full training step again, go to the model repository, copy the .caffemodel file(s) away, and start from scratch again, including training, but change iterations to a small number, like 1 or 2. Then go back to the model repository, remove the new .caffemodel files, copy back the old ones, kill the service, and re-create it with the PUT call above. You should be safe then.
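A minimal shell sketch of the recovery steps above (paths assume the models/n20 repository used throughout this tutorial):

cd models/n20
mkdir -p backup
cp *.caffemodel backup/    # save the trained weights away
# re-create the service and run a short training with iterations set to 1 or 2
rm -f *.caffemodel         # remove the snapshot files produced by the short run
cp backup/*.caffemodel .   # restore the original trained weights

Then kill the service and re-create it with the PUT call above.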