
How To Train Your Core ML Model On Device

This is the sequel to the previous post, How To Create Updatable Core ML Models With Core ML Tools.

With Core ML 3, training a Core ML model on a device is as easy as How To Train Your Dragon!

Inspired From How To Train Your Dragon


  • Get the Updatable mlmodel from the previous post
  • Just drag the mlmodel into your Xcode Project
  • Our Goal

    • Retrain a Cat vs Dog Classifier Core ML Model on Device by relabelling predicted images with the opposite label.
    • Train the batch of relabelled images on the device itself with our updatable model.
    • Save the updated model in the application’s Documents directory on the device and use it for future predictions and/or retraining.

    So for every predicted image, we’ll allow the user to classify it as the opposite label. A Predicted Cat can be relabelled as a Dog and vice versa. We’ll then retrain our model with these new images and labels.

    Before we dive into code, let’s get a hang of the Core ML classes and API we’ll be using.

    A Brief Look Into the Core ML API

    MLModel is the class that encapsulates the model.


    MLFeatureValue acts as a wrapper for the data. A Core ML model accepts its inputs and returns its outputs as MLFeatureValue instances.

    MLFeatureValue lets us directly use a CGImage. Along with that, we can pass the image constraints for the model. It creates the CVPixelBuffer from the CGImage for you thereby avoiding the need to write helper methods.

    This is how an MLFeatureValue instance is created from the image.

    let featureValue = try MLFeatureValue(cgImage: image.cgImage!, constraint: imageConstraint, options: nil)

    Now let’s look into MLImageConstraint.


    MLImageConstraint describes the image input the model expects, so that the model is fed an image of the correct size.
    It contains the input information. In our case that is the image size and the pixel format.
    We can easily retrieve the image constraint object from the model using the following piece of code:

    let imageConstraint = model?.modelDescription.inputDescriptionsByName["image"]!.imageConstraint!

    We just need to look up the input name (“image” in our case) in the model description’s inputDescriptionsByName dictionary.


    An MLFeatureValue is not passed into the model directly. It needs to be wrapped inside an MLFeatureProvider.

    If you inspect the Swift file that Xcode generates for the mlmodel, the generated input and output classes conform to the MLFeatureProvider protocol.
    To access an MLFeatureValue from an MLFeatureProvider, use its featureValue(for:) accessor method.

    MLDictionaryFeatureProvider is a convenience wrapper that holds the data in a dictionary format.
    It requires the input name (“image” in our case) as the key and MLFeatureValue as the value.
    If there is more than one input, just add them all to the same dictionary.
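    As a rough illustration of the dictionary wrapper described above, here is a minimal sketch. The "image" key matches our model’s input name; the "threshold" key is a hypothetical second input added only to show several inputs sharing one dictionary, and the blank pixel buffer simply stands in for a real image.

```swift
import CoreML
import CoreVideo

// Create a small blank pixel buffer that stands in for a real input image.
func makeBlankPixelBuffer(width: Int, height: Int) -> CVPixelBuffer? {
    var buffer: CVPixelBuffer?
    let status = CVPixelBufferCreate(kCFAllocatorDefault, width, height,
                                     kCVPixelFormatType_32BGRA, nil, &buffer)
    return status == kCVReturnSuccess ? buffer : nil
}

// Wrap every input in an MLFeatureValue and key it by the input name.
func makeInputs(image: CVPixelBuffer, threshold: Double) throws -> MLDictionaryFeatureProvider {
    return try MLDictionaryFeatureProvider(dictionary: [
        "image": MLFeatureValue(pixelBuffer: image),
        "threshold": MLFeatureValue(double: threshold) // hypothetical second input
    ])
}
```

    The resulting provider exposes its keys through featureNames and each wrapped value through featureValue(for:).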


    MLBatchProvider holds a collection of MLFeatureProviders for batch processing. We can hence predict multiple feature providers or train on a batch of training inputs encapsulated in an MLBatchProvider.
    In this article, we’ll be doing the latter.

    An MLArrayBatchProvider is a batch provider backed by an array of feature providers.
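    To make the batching idea concrete, here is a small sketch that wraps a few labels into feature providers and groups them into one batch. The helper name makeLabelBatch is made up for this example, and a real training batch would also carry the "image" feature next to "classLabel".

```swift
import CoreML

// Turn each label into its own feature provider, then batch them together.
func makeLabelBatch(labels: [String]) throws -> MLArrayBatchProvider {
    let providers = try labels.map { label -> MLFeatureProvider in
        try MLDictionaryFeatureProvider(dictionary: [
            "classLabel": MLFeatureValue(string: label)
        ])
    }
    return MLArrayBatchProvider(array: providers)
}
```

    The batch reports its size through count and hands back individual providers with features(at:).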


    An MLUpdateTask is responsible for updating the model with the new training inputs.

    Required Parameters

    • Model URL – The location of the compiled model (mlmodelc extension)
    • Training Data – An MLArrayBatchProvider
    • Model Configuration – Here we pass an MLModelConfiguration. We can use the existing model’s configuration or customize it. For example, we can restrict the model to run on the CPU only, on the CPU and GPU, or on all available units including the Neural Engine.
    • Completion Handler – It receives the context from which we can access the updated model. We can then write that model back to a URL, or handle it however we want.

    Optional Parameters

    • progressHandlers – Here you pass an MLUpdateProgressHandlers instance created with the set of events you want to listen to, such as training begin or epoch end. It is passed in place of the plain completion handler and carries its own.
    • progressHandler – The closure inside MLUpdateProgressHandlers that gets called whenever any of those events is triggered.

    To start the training, just call the resume() function on the updateTask instance.

    Here’s some pseudocode for training the model on a device:

    let updateTask = try MLUpdateTask(forModelAt: updatableModelURL,
                                      trainingData: trainingData,
                                      configuration: model.configuration,
                                      completionHandler: { context in
      // Training completed; context.model is the updated model
    })
    updateTask.resume()
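    For the optional progress handlers mentioned earlier, a hedged sketch could look like the following. The function name makeProgressHandlers and the log messages are our own; the metric keys (.epochIndex and .lossValue) come from MLMetricKey, and whether a loss value is present depends on the event being reported.

```swift
import CoreML

// Build handlers that log training progress and forward the final context.
func makeProgressHandlers(onFinish: @escaping (MLUpdateContext) -> Void) -> MLUpdateProgressHandlers {
    return MLUpdateProgressHandlers(
        forEvents: [.trainingBegin, .epochEnd],
        progressHandler: { context in
            if context.event.contains(.trainingBegin) {
                print("Training started")
            } else if context.event.contains(.epochEnd) {
                // Metrics are stored as Any, so cast them before use.
                let epoch = context.metrics[.epochIndex] as? Int ?? -1
                let loss = context.metrics[.lossValue] as? Double ?? .nan
                print("Epoch \(epoch) finished, loss: \(loss)")
            }
        },
        completionHandler: onFinish
    )
}
```

    The returned object goes into the progressHandlers parameter of MLUpdateTask instead of a plain completion handler.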

    Now that we’ve got an idea of the different components and their roles, let’s build our iOS Application that trains the model on the device.


    Our Storyboard


    Load A Model From A URL

    First, let’s load the compiled model (mlmodelc) from a URL. We start with the copy that Xcode places in the app bundle:

    private func loadModel(url: URL) -> MLModel? {
        do {
            let config = MLModelConfiguration()
            config.computeUnits = .all
            return try MLModel(contentsOf: url, configuration: config)
        } catch {
            print("Error loading model: \(error)")
            return nil
        }
    }

    // Bundle.main.url(forResource:withExtension:) returns an optional URL, so unwrap it first
    let modelURL = Bundle.main.url(forResource: "CatDogUpdatable", withExtension: "mlmodelc")
    let updatableModel = modelURL.flatMap { loadModel(url: $0) }

    Predict An Image From MLModel

    Now that we’ve got our MLModel from the URL, we’ll run the prediction code, assuming we’ve got the image from the UIImagePickerController.

    // Requires `import Vision` for VNImageCropAndScaleOption
    func predict(image: UIImage) -> Animal? {
        guard let imageConstraint = updatableModel?.modelDescription.inputDescriptionsByName["image"]?.imageConstraint else {
            return nil
        }
        do {
            let imageOptions: [MLFeatureValue.ImageOption: Any] = [
                .cropAndScale: VNImageCropAndScaleOption.scaleFill.rawValue
            ]
            let featureValue = try MLFeatureValue(cgImage: image.cgImage!, constraint: imageConstraint, options: imageOptions)
            let featureProviderDict = try MLDictionaryFeatureProvider(dictionary: ["image": featureValue])
            let prediction = try updatableModel?.prediction(from: featureProviderDict)
            let value = prediction?.featureValue(for: "classLabel")?.stringValue
            if value == "Dog" {
                return .dog
            } else {
                return .cat
            }
        } catch {
            print("error is \(error.localizedDescription)")
            return nil
        }
    }

    We just pass the UIImage as a CGImage to MLFeatureValue, together with the MLImageConstraint of the model input, wrap the result in an MLDictionaryFeatureProvider, and let the MLModel run the prediction on it.
    The prediction result is itself a feature provider with a set of featureNames. In our case, the classLabel feature contains the label cat or dog.
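    Besides the label, classifier outputs usually carry a dictionary of per-class probabilities. The sketch below reads it from any prediction result. The output name "classLabelProbs" is an assumption on our part (check the generated Swift file for the actual name), and topLabel is a helper invented for this example.

```swift
import CoreML

// Return the most probable label and its probability, if the output has them.
func topLabel(in output: MLFeatureProvider,
              probabilitiesName: String = "classLabelProbs") -> (label: String, confidence: Double)? {
    // dictionaryValue is empty when the feature is not a dictionary.
    guard let probs = output.featureValue(for: probabilitiesName)?.dictionaryValue,
          let best = probs.max(by: { $0.value.doubleValue < $1.value.doubleValue }),
          let label = best.key.base as? String else {
        return nil
    }
    return (label, best.value.doubleValue)
}
```

    Calling topLabel(in:) on the object returned by prediction(from:) gives a quick way to see how sure the model was before deciding whether an image is worth relabelling.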

    We keep a lookup dictionary from UIImage to label, named imageLabelDictionary.
    If we want to add an image to the training input, we store the image together with the inverse of the predicted output (cat/dog) in the dictionary.
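    The relabelling step can be sketched roughly as follows. To keep the example free of UIKit, it keys the dictionary by a hypothetical image identifier string instead of a UIImage. The "Dog" label matches the prediction code above, while "Cat" and the helper names are assumptions.

```swift
// The two classes of our classifier, with the label strings the model uses.
enum Animal: String {
    case cat = "Cat"
    case dog = "Dog"

    // The label a user picks when they disagree with the prediction.
    var opposite: Animal {
        return self == .cat ? .dog : .cat
    }
}

// Store the inverse of the predicted label for the given image.
func relabel(imageID: String, predicted: Animal, in store: inout [String: String]) {
    store[imageID] = predicted.opposite.rawValue
}
```

    Relabelling the same image twice simply overwrites the earlier entry, so each image contributes one training example.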

    Next, we create a batch provider out of the imageLabelDictionary.

    Create A Batch Provider

    Our batch provider function creates an MLArrayBatchProvider out of instances of the generated TrainingInput class, which requires the image as a CVPixelBuffer and the classLabel as a string.

    private func batchProvider() -> MLArrayBatchProvider {
        var batchInputs: [MLFeatureProvider] = []
        let imageOptions: [MLFeatureValue.ImageOption: Any] = [
            .cropAndScale: VNImageCropAndScaleOption.scaleFill.rawValue
        ]
        for (image, label) in imageLabelDictionary {
            do {
                let featureValue = try MLFeatureValue(cgImage: image.cgImage!, constraint: imageConstraint, options: imageOptions)
                if let pixelBuffer = featureValue.imageBufferValue {
                    let x = CatDogUpdatableTrainingInput(image: pixelBuffer, classLabel: label)
                    // Collect every training input for the batch
                    batchInputs.append(x)
                }
            } catch {
                print("error description is \(error.localizedDescription)")
            }
        }
        return MLArrayBatchProvider(array: batchInputs)
    }

    Thanks to MLFeatureValue, we can easily retrieve the pixel buffer through its imageBufferValue property.

    Retrieve URL of the MLModel

    We need to pass the model URL to MLUpdateTask. For that, we retrieve the URL from the application’s Documents directory using FileManager. The code is straightforward:

    let fileManager = FileManager.default
    let documentDirectory = try fileManager.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
    let modelURL = documentDirectory.appendingPathComponent("CatDog.mlmodelc")
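    The update task reads the compiled model from this URL, so a model has to exist there before the first training run. One way to handle that, assuming the model has not been copied yet, is to copy the bundled mlmodelc folder into the Documents directory once. The helper name installModelIfNeeded is our own.

```swift
import Foundation

// Copy the compiled model to the destination unless something is already there.
// Returns true when a copy was made, false when the destination already existed.
@discardableResult
func installModelIfNeeded(from sourceURL: URL, to destinationURL: URL,
                          fileManager: FileManager = .default) throws -> Bool {
    if fileManager.fileExists(atPath: destinationURL.path) {
        return false
    }
    // An .mlmodelc is a directory; copyItem copies it recursively.
    try fileManager.copyItem(at: sourceURL, to: destinationURL)
    return true
}
```

    In the app, sourceURL would be the URL returned by Bundle.main.url(forResource:withExtension:) and destinationURL the modelURL computed above. Later runs then keep using the retrained copy instead of overwriting it.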

    Now we are ready to train our model again with the new images.

    Train Your Model Using MLUpdateTask

    let modelConfig = MLModelConfiguration()
    modelConfig.computeUnits = .cpuAndGPU
    do {
        let updateTask = try MLUpdateTask(forModelAt: modelURL, trainingData: batchProvider(), configuration: modelConfig,
                                          progressHandlers: MLUpdateProgressHandlers(forEvents: [.trainingBegin, .epochEnd],
                                          progressHandler: { (contextProgress) in
                                              // Called for every registered event, e.g. to log progress
                                          }) { (finalContext) in
            // Completion handler: save and reload the model if training succeeded
            if finalContext.task.error == nil {
                let fileManager = FileManager.default
                do {
                    let documentDirectory = try fileManager.url(for: .documentDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
                    let fileURL = documentDirectory.appendingPathComponent("CatDog.mlmodelc")
                    try finalContext.model.write(to: fileURL)
                    self.updatableModel = self.loadModel(url: fileURL)
                } catch {
                    print("error is \(error.localizedDescription)")
                }
            }
        })
        // Training only starts once resume() is called
        updateTask.resume()
    } catch {
        print("error is \(error.localizedDescription)")
    }

    finalContext.model.write(to: fileURL) writes the updated model back to the file URL.
    Note: We’d set the number of epochs to 1 in the previous tutorial.
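    If one epoch turns out to be too few, one option is to override the value at runtime through the configuration’s parameters dictionary rather than rebuilding the model. This is only a sketch: makeUpdateConfiguration is a made-up helper, and whether the override is accepted depends on the model declaring epochs as an update parameter.

```swift
import CoreML

// Build a configuration that asks the update task for a custom epoch count.
func makeUpdateConfiguration(epochs: Int) -> MLModelConfiguration {
    let config = MLModelConfiguration()
    config.computeUnits = .cpuAndGPU
    // Overrides the epoch count stored in the model, if the model allows it.
    config.parameters = [MLParameterKey.epochs: epochs]
    return config
}
```

    The returned configuration can be passed to MLUpdateTask in place of modelConfig.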

    The output of the application in action is given below:


    As you can see in the gif above, we retrain the model on some of the images that were showing the incorrect label.
    The next time we ran the prediction, they showed a different label. Relabelling and retraining is not guaranteed to change the predicted output every time, though. It depends on a number of factors, and the model’s confidence is one of them.

    That sums up this article on Core ML on-device training. You can download the full source code from this GitHub Link.

    By Anupam Chugh

    iOS Developer exploring the depths of ML and AR on Mobile.
    Loves writing about thoughts, technology, and code.
