"""Convert a saved model to a tflite model.

Usage: python3 saved-model-to-tflite.py <mlgo saved_model_dir> <tflite dest_dir>

The <tflite dest_dir> will contain:
  model.tflite: the converted saved model.
  output_spec.json: the output spec, copied from the saved_model dir (only if
    the saved_model dir contains one).
"""

import tensorflow as tf
import os
import sys
# NOTE(review): greedy_policy is not referenced in this script; it may be kept
# only for an import-time side effect -- confirm before removing.
from tf_agents.policies import greedy_policy

_USAGE = ('Usage: python3 saved-model-to-tflite.py '
          '<mlgo saved_model_dir> <tflite dest_dir>')


def main(argv):
  """Convert the saved model in argv[1] into a tflite model under argv[2].

  Writes `model.tflite` into the destination directory (creating it if
  needed) and, when the saved model directory has an `output_spec.json`,
  copies that file alongside it.

  Args:
    argv: sys.argv-style list: [program, saved_model_dir, tflite_dest_dir].

  Raises:
    SystemExit: with a usage message if argv does not have exactly 3 entries.
  """
  # Validate explicitly rather than with `assert`, which is stripped under -O.
  if len(argv) != 3:
    sys.exit(_USAGE)
  sm_dir = argv[1]
  tfl_dir = argv[2]

  tf.io.gfile.makedirs(tfl_dir)
  tfl_path = os.path.join(tfl_dir, 'model.tflite')

  converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
  # Restrict the converted model to TFLite builtin ops only.
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
  ]
  tfl_model = converter.convert()
  with tf.io.gfile.GFile(tfl_path, 'wb') as f:
    f.write(tfl_model)

  json_file = 'output_spec.json'
  src_json = os.path.join(sm_dir, json_file)
  if tf.io.gfile.exists(src_json):
    # overwrite=True so re-running into an existing destination directory
    # does not fail after model.tflite has already been rewritten.
    tf.io.gfile.copy(src_json, os.path.join(tfl_dir, json_file),
                     overwrite=True)


if __name__ == '__main__':
  main(sys.argv)