From 86b38ce38b4790759922cad02bd4d5d56e86d2b6 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Wed, 23 Mar 2022 15:08:11 +0100 Subject: [PATCH 001/154] ENH: add TFRecord class + to_tfrecords method + split "array" data from "scalar" data + add max_nb_of_samples for Dataset --- python/otbtf.py | 241 +++++++++++++++++++++++++++++++++++++++++++++-- python/system.py | 68 +++++++++++++ 2 files changed, 301 insertions(+), 8 deletions(-) create mode 100644 python/system.py diff --git a/python/otbtf.py b/python/otbtf.py index a23d5237..45e13791 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -18,17 +18,23 @@ # # ==========================================================================*/ """ -Contains stuff to help working with TensorFlow and geospatial data in the -OTBTF framework. +Contains stuff to help working with TensorFlow and geospatial data in the OTBTF framework. """ +import glob +import json +import os import threading import multiprocessing import time import logging from abc import ABC, abstractmethod +from functools import partial +from tqdm import tqdm + import numpy as np import tensorflow as tf import gdal +import system # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -167,13 +173,18 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict: dict, use_streaming=False): + def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif], ... src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]} + :param scalar_dict: (optional) a dict containing list of scalars (int, float, str) as follow: + {scalar_name1: ["value_1", ..., "value_N"], + scalar_name2: [value_1, ..., value_N], + ... + scalar_nameN: [value1, ..., value_N]} :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. """ @@ -182,8 +193,13 @@ class PatchesImagesReader(PatchesReaderBase): # gdal_ds dict self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} + # Scalar parameters (e.g. metadatas) + self.scalar_dict = scalar_dict + if scalar_dict is None: + self.scalar_dict = {} + # check number of patches in each sources - if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1: + if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1: raise Exception("Each source must have the same number of patches images") # streaming on/off @@ -213,6 +229,12 @@ class PatchesImagesReader(PatchesReaderBase): if not self.use_streaming: patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} + # Create a scalars dict so that one scalar <-> one patch + self.scalar_buffer = {} + for src_key, scalars in self.scalar_dict.items(): + self.scalar_buffer[src_key] = [] + for scalar, ds_size in zip(scalars, self.ds_sizes): + self.scalar_buffer[src_key].extend([scalar] * ds_size) def _get_ds_and_offset_from_index(self, index): offset = index @@ -254,9 +276,11 @@ class PatchesImagesReader(PatchesReaderBase): if not self.use_streaming: res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} + res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()}) else: i, offset = self._get_ds_and_offset_from_index(index) res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} + res.update({key: np.asarray(scalars[i]) for key, scalars in self.scalar_dict.items()}) return res @@ -362,16 +386,18 @@ class Dataset: """ def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128, - Iterator: IteratorBase = RandomIterator): + Iterator=RandomIterator, max_nb_of_samples=None): """ :param patches_reader: The patches reader instance :param buffer_length: The number of samples that are stored in the buffer :param Iterator: The iterator class used to generate the sequence of patches indices. + :param max_nb_of_samples: Optional, max number of samples to consider """ # patches reader self.patches_reader = patches_reader - self.size = self.patches_reader.get_size() + self.size = min(self.patches_reader.get_size(), + max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size() # iterator self.iterator = Iterator(patches_reader=self.patches_reader) @@ -380,6 +406,7 @@ class Dataset: self.output_types = dict() self.output_shapes = dict() one_sample = self.patches_reader.get_sample(index=0) + print(one_sample) # DEBUG for src_key, np_arr in one_sample.items(): self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) @@ -404,6 +431,14 @@ class Dataset: output_types=self.output_types, output_shapes=self.output_shapes).repeat(1) + def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True): + """ + + """ + tfrecord = TFRecords(output_dir) + tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder) + + def get_stats(self) -> dict: """ :return: the dataset statistics, computed by the patches reader @@ -502,8 +537,8 @@ class DatasetFromPatchesImages(Dataset): :see Dataset """ - def __init__(self, filenames_dict: dict, use_streaming: bool = False, buffer_length: int = 128, - Iterator: IteratorBase = RandomIterator): + def __init__(self, filenames_dict, use_streaming=False, buffer_length: int = 128, + Iterator=RandomIterator): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image1, ..., src1_patches_imageN1], @@ -518,3 +553,193 @@ class DatasetFromPatchesImages(Dataset): patches_reader = PatchesImagesReader(filenames_dict=filenames_dict, use_streaming=use_streaming) super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator) + + +class TFRecords: + """ + This class allows to convert Dataset objects to TFRecords and to load them in dataset tensorflows format. + """ + + def __init__(self, path): + """ + :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path + """ + if system.is_dir(path) or not os.path.exists(path): + self.dirpath = path + system.mkdir(self.dirpath) + self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath)) + else: + self.dirpath = system.dirname(path) + self.tfrecords_pattern_path = path + self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath)) + self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath)) + self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None + self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None + + def _bytes_feature(self, value): + """ + Used to convert a value to a type compatible with tf.train.Example. + :param value: value + :return a bytes_list from a string / byte. + """ + if isinstance(value, type(tf.constant(0))): + value = value.numpy() # BytesList won't unpack a string from an EagerTensor. + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True): + """ + Convert and save samples from dataset object to tfrecord files. + :param dataset: Dataset object to convert into a set of tfrecords + :param n_samples_per_shard: Number of samples per shard + :param drop_remainder: Whether additional samples should be dropped. Advisable if using multiworkers training. + If True, all TFRecords will have `n_samples_per_shard` samples + """ + logging.info("%s samples", dataset.size) + + nb_shards = (dataset.size // n_samples_per_shard) + if not drop_remainder and dataset.size % n_samples_per_shard > 0: + nb_shards += 1 + + self.convert_dataset_output_shapes(dataset) + + def _convert_data(data): + """ + Convert data + """ + data_converted = {} + + for k, d in data.items(): + data_converted[k] = d.name + + return data_converted + + self.save(_convert_data(dataset.output_types), self.output_types_file) + + for i in tqdm(range(nb_shards)): + + if (i + 1) * n_samples_per_shard <= dataset.size: + nb_sample = n_samples_per_shard + else: + nb_sample = dataset.size - i * n_samples_per_shard + + filepath = "{}{}.records".format(system.pathify(self.dirpath), i) + with tf.io.TFRecordWriter(filepath) as writer: + for s in range(nb_sample): + sample = dataset.read_one_sample() + serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} + features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in + serialized_sample.items()} + tf_features = tf.train.Features(feature=features) + example = tf.train.Example(features=tf_features) + writer.write(example.SerializeToString()) + + @staticmethod + def save(data, filepath): + """ + Save data to pickle format. + :param data: Data to save json format + :param filepath: Output file name + """ + + with open(filepath, 'w') as f: + json.dump(data, f, indent=4) + + @staticmethod + def load(filepath): + """ + Return data from pickle format. + :param filepath: Input file name + """ + with open(filepath, 'r') as f: + return json.load(f) + + def convert_dataset_output_shapes(self, dataset): + """ + Convert and save numpy shape to tensorflow shape. + :param dataset: Dataset object containing output shapes + """ + output_shapes = {} + + for key in dataset.output_shapes.keys(): + output_shapes[key] = (None,) + dataset.output_shapes[key] + + self.save(output_shapes, self.output_shape_file) + + @staticmethod + def parse_tfrecord(example, features_types, target_keys): + """ + Parse example object to sample dict. + :param example: Example object to parse + :param features_types: List of types for each feature + :param target_keys: list of keys of the targets + """ + read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} + example_parsed = tf.io.parse_single_example(example, read_features) + + for key in read_features.keys(): + example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key]) + + # Differentiating inputs and outputs + input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} + target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + + return input_parsed, target_parsed + + + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + """ + Read all tfrecord files matching with pattern and convert data to tensorflow dataset. + :param batch_size: Size of tensorflow batch + :param target_key: Key of the target, e.g. 's2_out' + :param n_workers: number of workers, e.g. 4 if using 4 GPUs + e.g. 12 if using 3 nodes of 4 GPUs + :param drop_remainder: whether the last batch should be dropped in the case it has fewer than + `batch_size` elements. True is advisable when training on multiworkers. + False is advisable when evaluating metrics so that all samples are used + :param shuffle_buffer_size: is None, shuffle is not used. Else, blocks of shuffle_buffer_size + elements are shuffled using uniform random. + """ + options = tf.data.Options() + if shuffle_buffer_size: + options.experimental_deterministic = False # disable order, increase speed + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + + # TODO: to be investigated : + # 1/ num_parallel_reads useful ? I/O bottleneck of not ? + # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ? + # 3/ shuffle or not shuffle ? + matching_files = glob.glob(self.tfrecords_pattern_path) + logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path) + logging.info('Number of matching TFRecords: %s', len(matching_files)) + matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)] # files multiple of workers + nb_matching_files = len(matching_files) + if nb_matching_files == 0: + raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord " + "files is greater or equal than the number of workers!".format(self.tfrecords_pattern_path)) + logging.info('Reducing number of records to : %s', nb_matching_files) + dataset = tf.data.TFRecordDataset(matching_files) # , num_parallel_reads=2) # interleaves reads from xxx files + dataset = dataset.with_options(options) # uses data as soon as it streams in, rather than in its original order + dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE) + if shuffle_buffer_size: + dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + # TODO voir si on met le prefetch avant le batch cf https://keras.io/examples/keras_recipes/tfrecord/ + + return dataset + + def read_one_sample(self, target_keys): + """ + Read one tfrecord file matching with pattern and convert data to tensorflow dataset. + :param target_key: Key of the target, e.g. 's2_out' + """ + matching_files = glob.glob(self.tfrecords_pattern_path) + one_file = matching_files[0] + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + dataset = tf.data.TFRecordDataset(one_file) + dataset = dataset.map(parse) + dataset = dataset.batch(1) + + sample = iter(dataset).get_next() + return sample diff --git a/python/system.py b/python/system.py new file mode 100644 index 00000000..e7b581fd --- /dev/null +++ b/python/system.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) 2020-2022 INRAE + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +"""Various system operations""" +import logging +import pathlib +import os + +# ---------------------------------------------------- Helpers --------------------------------------------------------- + +def pathify(pth): + """ Adds posix separator if needed """ + if not pth.endswith("/"): + pth += "/" + return pth + + +def mkdir(pth): + """ Create a directory """ + path = pathlib.Path(pth) + path.mkdir(parents=True, exist_ok=True) + + +def dirname(filename): + """ Returns the parent directory of the file """ + return str(pathlib.Path(filename).parent) + + +def basic_logging_init(): + """ basic logging initialization """ + logging.basicConfig( + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + + +def logging_info(msg, verbose=True): + """ + Prints log info only if required by `verbose` + :param msg: message to log + :param verbose: boolean. Whether to log msg or not. Default True + :return: + """ + if verbose: + logging.info(msg) + +def is_dir(filename): + """ return True if filename is the path to a directory """ + return os.path.isdir(filename) -- GitLab From 28e489468ece5134fd441df1609a8ded77a95531 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 24 Mar 2022 16:43:46 +0100 Subject: [PATCH 002/154] ENH: split the Dataset initialisation and the patch_reader feeding --- python/otbtf.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 45e13791..8f8f9359 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -385,7 +385,7 @@ class Dataset: :see Buffer """ - def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128, + def __init__(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128, Iterator=RandomIterator, max_nb_of_samples=None): """ :param patches_reader: The patches reader instance @@ -393,8 +393,19 @@ class Dataset: :param Iterator: The iterator class used to generate the sequence of patches indices. :param max_nb_of_samples: Optional, max number of samples to consider """ - # patches reader + if patches_reader: + self.feed(patches_reader, buffer_length, Iterator, max_nb_of_samples) + + + def feed(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128, + Iterator=RandomIterator, max_nb_of_samples=None): + """ + :param patches_reader: The patches reader instance + :param buffer_length: The number of samples that are stored in the buffer + :param Iterator: The iterator class used to generate the sequence of patches indices. + :param max_nb_of_samples: Optional, max number of samples to consider + """ self.patches_reader = patches_reader self.size = min(self.patches_reader.get_size(), max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size() -- GitLab From c774e9b62687d07abeb9a8829aace58a2da8a0c0 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Wed, 30 Mar 2022 14:48:09 +0200 Subject: [PATCH 003/154] ENH: add more log info --- python/otbtf.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 8f8f9359..17fd06c0 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -407,8 +407,14 @@ class Dataset: :param max_nb_of_samples: Optional, max number of samples to consider """ self.patches_reader = patches_reader - self.size = min(self.patches_reader.get_size(), - max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size() + + # If necessary, limit the nb of samples + logging.info('There are %s samples available', self.patches_reader.get_size()) + if max_nb_of_samples and self.patches_reader.get_size() > max_nb_of_samples: + logging.info('Reducing number of samples to %s', max_nb_of_samples) + self.size = max_nb_of_samples + else: + self.size = self.patches_reader.get_size() # iterator self.iterator = Iterator(patches_reader=self.patches_reader) -- GitLab From 120a7ba69b75400dc56c8aa70ff5902cb8a91cc0 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Mon, 4 Apr 2022 16:36:03 +0200 Subject: [PATCH 004/154] FIX: remove API breaker --- python/otbtf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 17fd06c0..d13d5788 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -173,19 +173,19 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False): + def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif], ... src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]} + :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. :param scalar_dict: (optional) a dict containing list of scalars (int, float, str) as follow: {scalar_name1: ["value_1", ..., "value_N"], scalar_name2: [value_1, ..., value_N], ... - scalar_nameN: [value1, ..., value_N]} - :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. + scalar_nameM: [value1, ..., valueN]} """ assert len(filenames_dict.values()) > 0 -- GitLab From 817a69c7806e7dda1976916339743c0067d78304 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Mon, 4 Apr 2022 16:55:39 +0200 Subject: [PATCH 005/154] FIX: add tqdm dependency --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 34b9b4a4..ece649b3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,7 @@ RUN if $GUI; then \ RUN ln -s /usr/bin/python3 /usr/local/bin/python && ln -s /usr/bin/pip3 /usr/local/bin/pip # NumPy version is conflicting with system's gdal dep and may require venv ARG NUMPY_SPEC="==1.19.*" -RUN pip install --no-cache-dir -U pip wheel mock six future deprecated "numpy$NUMPY_SPEC" \ +RUN pip install --no-cache-dir -U pip wheel mock six future tqdm deprecated "numpy$NUMPY_SPEC" \ && pip install --no-cache-dir --no-deps keras_applications keras_preprocessing # ---------------------------------------------------------------------------- -- GitLab From 0da55b2105cb256d2f6212ebe06033002a8c0164 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Mon, 4 Apr 2022 17:02:24 +0200 Subject: [PATCH 006/154] REFAC: replace system by plain python os library --- python/otbtf.py | 15 +++++------ python/system.py | 68 ------------------------------------------------ 2 files changed, 7 insertions(+), 76 deletions(-) delete mode 100644 python/system.py diff --git a/python/otbtf.py b/python/otbtf.py index d13d5788..c702375e 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -34,7 +34,6 @@ from tqdm import tqdm import numpy as np import tensorflow as tf import gdal -import system # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -581,15 +580,15 @@ class TFRecords: """ :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path """ - if system.is_dir(path) or not os.path.exists(path): + if os.path.isdir(path) or not os.path.exists(path): self.dirpath = path - system.mkdir(self.dirpath) - self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath)) + os.makedirs(self.dirpath, exist_ok=True) + self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") else: - self.dirpath = system.dirname(path) + self.dirpath = os.path.dirname(path) self.tfrecords_pattern_path = path - self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath)) - self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath)) + self.output_types_file = os.path.join(self.dirpath, "output_types.json") + self.output_shape_file = os.path.join(self.dirpath, "output_shape.json") self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None @@ -639,7 +638,7 @@ class TFRecords: else: nb_sample = dataset.size - i * n_samples_per_shard - filepath = "{}{}.records".format(system.pathify(self.dirpath), i) + filepath = os.path.join(self.dirpath, f"{i}.records") with tf.io.TFRecordWriter(filepath) as writer: for s in range(nb_sample): sample = dataset.read_one_sample() diff --git a/python/system.py b/python/system.py deleted file mode 100644 index e7b581fd..00000000 --- a/python/system.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Copyright (c) 2020-2022 INRAE - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" -"""Various system operations""" -import logging -import pathlib -import os - -# ---------------------------------------------------- Helpers --------------------------------------------------------- - -def pathify(pth): - """ Adds posix separator if needed """ - if not pth.endswith("/"): - pth += "/" - return pth - - -def mkdir(pth): - """ Create a directory """ - path = pathlib.Path(pth) - path.mkdir(parents=True, exist_ok=True) - - -def dirname(filename): - """ Returns the parent directory of the file """ - return str(pathlib.Path(filename).parent) - - -def basic_logging_init(): - """ basic logging initialization """ - logging.basicConfig( - format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') - - -def logging_info(msg, verbose=True): - """ - Prints log info only if required by `verbose` - :param msg: message to log - :param verbose: boolean. Whether to log msg or not. Default True - :return: - """ - if verbose: - logging.info(msg) - -def is_dir(filename): - """ return True if filename is the path to a directory """ - return os.path.isdir(filename) -- GitLab From e82173f7268ba63dea90510b2503fd72fb7fa1e1 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Tue, 12 Apr 2022 12:24:38 +0200 Subject: [PATCH 007/154] FIX: make `_read_extract_as_np_arr` method return 3D arrays even for singleband --- python/otbtf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/otbtf.py b/python/otbtf.py index c702375e..71e2f6a3 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -258,6 +258,9 @@ class PatchesImagesReader(PatchesReaderBase): buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) + else: # single-band raster + buffer = np.expand_dims(buffer, axis=2) + return np.float32(buffer) def get_sample(self, index): -- GitLab From 4ab9bdb59daae0f6adfbec9990bd4ff9b970398f Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Tue, 19 Apr 2022 16:41:39 +0200 Subject: [PATCH 008/154] ENH: add the possibility to specify cropping of the target when reading TFRecords --- python/otbtf.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 61a0767f..598eab65 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -685,12 +685,13 @@ class TFRecords: self.save(output_shapes, self.output_shape_file) @staticmethod - def parse_tfrecord(example, features_types, target_keys): + def parse_tfrecord(example, features_types, target_keys, target_cropping=None): """ Parse example object to sample dict. :param example: Example object to parse :param features_types: List of types for each feature :param target_keys: list of keys of the targets + :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor. """ read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} example_parsed = tf.io.parse_single_example(example, read_features) @@ -702,27 +703,33 @@ class TFRecords: input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + if target_cropping: + print({key: value for key, value in target_parsed.items()}) + target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()} + return input_parsed, target_parsed - def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. :param batch_size: Size of tensorflow batch - :param target_key: Key of the target, e.g. 's2_out' + :param target_keys: Keys of the target, e.g. ['s2_out'] + :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network + architecture coherent with this, i.e. that has a Cropping2D layer in the end :param n_workers: number of workers, e.g. 4 if using 4 GPUs e.g. 12 if using 3 nodes of 4 GPUs :param drop_remainder: whether the last batch should be dropped in the case it has fewer than `batch_size` elements. True is advisable when training on multiworkers. False is advisable when evaluating metrics so that all samples are used - :param shuffle_buffer_size: is None, shuffle is not used. Else, blocks of shuffle_buffer_size + :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size elements are shuffled using uniform random. """ options = tf.data.Options() if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping) # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? -- GitLab From 31a4f931a29899431e62cbaabd6d86747c3ecd37 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Tue, 19 Apr 2022 16:53:59 +0200 Subject: [PATCH 009/154] STYLE: remove debug prints --- python/otbtf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 598eab65..c944728c 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -425,7 +425,6 @@ class Dataset: self.output_types = dict() self.output_shapes = dict() one_sample = self.patches_reader.get_sample(index=0) - print(one_sample) # DEBUG for src_key, np_arr in one_sample.items(): self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) @@ -704,7 +703,6 @@ class TFRecords: target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} if target_cropping: - print({key: value for key, value in target_parsed.items()}) target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()} return input_parsed, target_parsed -- GitLab From c3eb4703503a45e5cc8b9a4ea409c7341d46558d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 10:49:10 +0200 Subject: [PATCH 010/154] ADD: modifications --- python/otbtf.py | 74 ++++++++++++++++++------------------------------- 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index c944728c..922d13db 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -172,7 +172,7 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None): + def __init__(self, filenames_dict, use_streaming=False, scalar_dict={}): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], @@ -192,18 +192,18 @@ class PatchesImagesReader(PatchesReaderBase): # gdal_ds dict self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} - # Scalar parameters (e.g. metadatas) - self.scalar_dict = scalar_dict - if scalar_dict is None: - self.scalar_dict = {} + # streaming on/off + self.use_streaming = use_streaming + + # Scalar dict (e.g. for metadata) + # If the scalars are not numpy.ndarray, convert them + self.scalar_dict = {key: [i if isinstance(i, np.ndarray) else np.asarray(i) for i in scalars] + for key, scalars in scalar_dict.items()} # check number of patches in each sources if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1: raise Exception("Each source must have the same number of patches images") - # streaming on/off - self.use_streaming = use_streaming - # gdal_ds check nb_of_patches = {key: 0 for key in self.gdal_ds} self.nb_of_channels = dict() @@ -226,14 +226,8 @@ class PatchesImagesReader(PatchesReaderBase): # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} - self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} - # Create a scalars dict so that one scalar <-> one patch - self.scalar_buffer = {} - for src_key, scalars in self.scalar_dict.items(): - self.scalar_buffer[src_key] = [] - for scalar, ds_size in zip(scalars, self.ds_sizes): - self.scalar_buffer[src_key].extend([scalar] * ds_size) + self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds[src_key]], axis=0) for + src_key, src_ds in self.gdal_ds.items()} def _get_ds_and_offset_from_index(self, index): offset = index @@ -276,20 +270,19 @@ class PatchesImagesReader(PatchesReaderBase): assert index >= 0 assert index < self.size + i, offset = self._get_ds_and_offset_from_index(index) + res = {src_key: scalar[i] for src_key, scalar in self.scalar_dict.items()} if not self.use_streaming: - res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} - res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()}) + res.update({src_key: arr[index, :, :, :] for src_key, arr in self.patches_buffer.items()}) else: - i, offset = self._get_ds_and_offset_from_index(index) - res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} - res.update({key: np.asarray(scalars[i]) for key, scalars in self.scalar_dict.items()}) - + res.update({src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) + for src_key in self.gdal_ds}) return res def get_stats(self): """ Compute some statistics for each source. - Depending if streaming is used, the statistics are computed directly in memory, or chunk-by-chunk. + When streaming is used, chunk-by-chunk. Else, the statistics are computed directly in memory. :return statistics dict """ @@ -340,6 +333,7 @@ class IteratorBase(ABC): """ Base class for iterators """ + @abstractmethod def __init__(self, patches_reader: PatchesReaderBase): pass @@ -396,22 +390,10 @@ class Dataset: :param max_nb_of_samples: Optional, max number of samples to consider """ # patches reader - if patches_reader: - self.feed(patches_reader, buffer_length, Iterator, max_nb_of_samples) - - - def feed(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128, - Iterator=RandomIterator, max_nb_of_samples=None): - """ - :param patches_reader: The patches reader instance - :param buffer_length: The number of samples that are stored in the buffer - :param Iterator: The iterator class used to generate the sequence of patches indices. - :param max_nb_of_samples: Optional, max number of samples to consider - """ self.patches_reader = patches_reader # If necessary, limit the nb of samples - logging.info('There are %s samples available', self.patches_reader.get_size()) + logging.info('Number of samples: %s', self.patches_reader.get_size()) if max_nb_of_samples and self.patches_reader.get_size() > max_nb_of_samples: logging.info('Reducing number of samples to %s', max_nb_of_samples) self.size = max_nb_of_samples @@ -451,14 +433,19 @@ class Dataset: def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True): """ + Save the dataset into TFRecord files + :param output_dir: output directory + :param n_samples_per_shard: number of samples per TFRecord file + :param drop_remainder: drop remainder samples """ tfrecord = TFRecords(output_dir) tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder) - def get_stats(self) -> dict: """ + Compute dataset statistics + :return: the dataset statistics, computed by the patches reader """ with self.mining_lock: @@ -684,13 +671,12 @@ class TFRecords: self.save(output_shapes, self.output_shape_file) @staticmethod - def parse_tfrecord(example, features_types, target_keys, target_cropping=None): + def parse_tfrecord(example, features_types, target_keys): """ Parse example object to sample dict. :param example: Example object to parse :param features_types: List of types for each feature :param target_keys: list of keys of the targets - :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor. """ read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} example_parsed = tf.io.parse_single_example(example, read_features) @@ -702,19 +688,13 @@ class TFRecords: input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} - if target_cropping: - target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()} - return input_parsed, target_parsed - - def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. :param batch_size: Size of tensorflow batch :param target_keys: Keys of the target, e.g. ['s2_out'] - :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network - architecture coherent with this, i.e. that has a Cropping2D layer in the end :param n_workers: number of workers, e.g. 4 if using 4 GPUs e.g. 12 if using 3 nodes of 4 GPUs :param drop_remainder: whether the last batch should be dropped in the case it has fewer than @@ -727,7 +707,7 @@ class TFRecords: if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping) + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? -- GitLab From 5186eb5e0d6a771f437e7666b576fee7769c521c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 11:38:47 +0200 Subject: [PATCH 011/154] ENH: use default arg as None instead {} --- python/otbtf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 922d13db..ad77b7a2 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -172,7 +172,7 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict, use_streaming=False, scalar_dict={}): + def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], @@ -198,7 +198,7 @@ class PatchesImagesReader(PatchesReaderBase): # Scalar dict (e.g. for metadata) # If the scalars are not numpy.ndarray, convert them self.scalar_dict = {key: [i if isinstance(i, np.ndarray) else np.asarray(i) for i in scalars] - for key, scalars in scalar_dict.items()} + for key, scalars in scalar_dict.items()} if scalar_dict else {} # check number of patches in each sources if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1: -- GitLab From 878dde7dd0ab49a28d18acf9ad5907e1a4e70f70 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 11:44:40 +0200 Subject: [PATCH 012/154] REFAC: change import order --- python/otbtf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index ad77b7a2..cbb96e55 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -29,11 +29,10 @@ import time import logging from abc import ABC, abstractmethod from functools import partial -from tqdm import tqdm - import numpy as np import tensorflow as tf from osgeo import gdal +from tqdm import tqdm # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -581,9 +580,10 @@ class TFRecords: self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None + @staticmethod def _bytes_feature(self, value): """ - Used to convert a value to a type compatible with tf.train.Example. + Convert a value to a type compatible with tf.train.Example. :param value: value :return a bytes_list from a string / byte. """ -- GitLab From 2cc5a837a0e45c223152313f8c1bb12c56e6fcd0 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 14:58:17 +0200 Subject: [PATCH 013/154] WIP: use godzilla runner --- .gitlab-ci.yml | 102 +++++++++++++++++++++++++++---------------------- 1 file changed, 57 insertions(+), 45 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 5cfe0b9e..d1517df4 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -18,96 +18,109 @@ stages: - Test - Applications Test -.update_otbtf_src: &update_otbtf_src - - sudo rm -rf $OTBTF_SRC && sudo ln -s $PWD $OTBTF_SRC # Replace local OTBTF source directory - -.compile_otbtf: &compile_otbtf - - cd $OTB_BUILD && sudo make install -j$(nproc --all) # Rebuild OTB with new OTBTF sources - -.install_pytest: &install_pytest - - pip3 install pytest pytest-cov pytest-order # Install pytest stuff - -before_script: - - *update_otbtf_src - -build: - stage: Build +Build the docker image: + stage: Docker build cpu-basic-dev allow_failure: false + tags: [godzilla] + image: docker/compose:1.29.2 + variables: + DOCKER_TLS_CERTDIR: "" + DOCKER_HOST: tcp://docker:2375 + services: + - name: docker:dind + command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] + before_script: + # docker login asks for the password to be passed through stdin for security + # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab + # https://docs.gitlab.com/ce/ci/variables/predefined_variables.html + - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - *compile_otbtf + - docker info + - > + docker build + --pull + --cache-from $CI_REGISTRY_IMAGE:latest + --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" + --label "org.opencontainers.image.url=$CI_PROJECT_URL" + --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" + --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" + --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" + --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + --build-arg BASE_IMAGE="ubuntu:20.04" + --build-arg BZL_CONFIGS="" + --build-arg KEEP_SRC_OTB=true + . + - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:cpu-basic-test -flake8: +.static_analysis_base: + image: $CI_REGISTRY_IMAGE:cpu-basic-test stage: Static Analysis allow_failure: true + +flake8: + extends: .static_analysis_base script: - sudo apt update && sudo apt install flake8 -y - python -m flake8 --max-line-length=120 $OTBTF_SRC/python pylint: - stage: Static Analysis - allow_failure: true + extends: .static_analysis_base script: - sudo apt update && sudo apt install pylint -y - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/python codespell: - stage: Static Analysis - allow_failure: true + extends: .static_analysis_base script: - sudo pip install codespell && codespell cppcheck: - stage: Static Analysis - allow_failure: true + extends: .static_analysis_base script: - sudo apt update && sudo apt install cppcheck -y - cd $OTBTF_SRC/ && cppcheck --enable=all --error-exitcode=1 -I include/ --suppress=missingInclude --suppress=unusedFunction . +.tests_base: + image: $CI_REGISTRY_IMAGE:cpu-basic-test + artifacts: + paths: + - $ARTIFACT_TEST_DIR/*.* + expire_in: 1 week + when: on_failure + ctest: + extends: .tests_base stage: Test script: - - *compile_otbtf - - sudo rm -rf $OTB_TEST_DIR/* # Empty testing temporary folder (old files here) - cd $OTB_BUILD/ && sudo ctest -L OTBTensorflow # Run ctest after_script: - cp -r $OTB_TEST_DIR $ARTIFACT_TEST_DIR - artifacts: - paths: - - $ARTIFACT_TEST_DIR/*.* - expire_in: 1 week - when: on_failure .applications_test_base: + extends: .tests_base stage: Applications Test rules: # Only for MR targeting 'develop' and 'master' branches because applications tests are slow - if: $CI_MERGE_REQUEST_ID && $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == 'develop' - if: $CI_MERGE_REQUEST_ID && $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == 'master' - artifacts: - when: on_failure - paths: - - $CI_PROJECT_DIR/report_*.xml - - $ARTIFACT_TEST_DIR/*.* - expire_in: 1 week - + before_script: + - pip3 install pytest pytest-cov pytest-order + - mkdir -p $ARTIFACT_TEST_DIR + - cd $CI_PROJECT_DIR + crc_book: extends: .applications_test_base script: - - *compile_otbtf - - *install_pytest - - cd $CI_PROJECT_DIR - mkdir -p $CRC_BOOK_TMP - TMPDIR=$CRC_BOOK_TMP DATADIR=$CI_PROJECT_DIR/test/data python -m pytest --junitxml=$CI_PROJECT_DIR/report_tutorial.xml $OTBTF_SRC/test/tutorial_unittest.py after_script: - - mkdir -p $ARTIFACT_TEST_DIR - cp $CRC_BOOK_TMP/*.* $ARTIFACT_TEST_DIR/ sr4rs: extends: .applications_test_base script: - - *compile_otbtf - - *install_pytest - - cd $CI_PROJECT_DIR - wget -O sr4rs_sentinel2_bands4328_france2020_savedmodel.zip https://nextcloud.inrae.fr/s/EZL2JN7SZyDK8Cf/download/sr4rs_sentinel2_bands4328_france2020_savedmodel.zip - unzip -o sr4rs_sentinel2_bands4328_france2020_savedmodel.zip @@ -116,5 +129,4 @@ sr4rs: - rm -rf sr4rs - git clone https://github.com/remicres/sr4rs.git - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - - python -m pytest --junitxml=$CI_PROJECT_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py - + - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py -- GitLab From e91d542127736540185107b8e4103830ed712c57 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:00:23 +0200 Subject: [PATCH 014/154] WIP: use godzilla runner --- .gitlab-ci.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d1517df4..79cccb82 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: gitlab-registry.irstea.fr/remi.cresson/otbtf:3.1-cpu-basic-testing +image: $CI_REGISTRY_IMAGE:cpu-basic-test variables: OTB_BUILD: /src/otb/build/OTB/build # Local OTB build directory @@ -18,8 +18,8 @@ stages: - Test - Applications Test -Build the docker image: - stage: Docker build cpu-basic-dev +Build the cpu-basic-dev docker image: + stage: Build allow_failure: false tags: [godzilla] image: docker/compose:1.29.2 @@ -55,7 +55,6 @@ Build the docker image: - docker push $CI_REGISTRY_IMAGE:cpu-basic-test .static_analysis_base: - image: $CI_REGISTRY_IMAGE:cpu-basic-test stage: Static Analysis allow_failure: true @@ -83,7 +82,6 @@ cppcheck: - cd $OTBTF_SRC/ && cppcheck --enable=all --error-exitcode=1 -I include/ --suppress=missingInclude --suppress=unusedFunction . .tests_base: - image: $CI_REGISTRY_IMAGE:cpu-basic-test artifacts: paths: - $ARTIFACT_TEST_DIR/*.* -- GitLab From e1a552bc292433432ce3106f64af4885d53b9638 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:06:30 +0200 Subject: [PATCH 015/154] WIP: use godzilla runner --- .gitlab-ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 79cccb82..45e250bb 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,12 +22,12 @@ Build the cpu-basic-dev docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker/compose:1.29.2 + image: docker:latest variables: DOCKER_TLS_CERTDIR: "" - DOCKER_HOST: tcp://docker:2375 + DOCKER_HOST: tcp://localhost:2375/ services: - - name: docker:dind + - name: docker:18.09.7-dind command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] before_script: # docker login asks for the password to be passed through stdin for security -- GitLab From 57b58a39f33fc452ac7884f3a59ee89975726d0f Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:16:37 +0200 Subject: [PATCH 016/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 45e250bb..98d4dc8d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -18,7 +18,7 @@ stages: - Test - Applications Test -Build the cpu-basic-dev docker image: +Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] -- GitLab From 74bc99453336bd41fd8244773d9b4deaeb7e5514 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:21:26 +0200 Subject: [PATCH 017/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 98d4dc8d..e3b6a1e8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,7 +25,7 @@ Build the CI docker image: image: docker:latest variables: DOCKER_TLS_CERTDIR: "" - DOCKER_HOST: tcp://localhost:2375/ + DOCKER_HOST: tcp://docker:2375/ services: - name: docker:18.09.7-dind command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] -- GitLab From e77077d6dc1716bb4816ddf7974dc757c819b235 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:22:35 +0200 Subject: [PATCH 018/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e3b6a1e8..98d4dc8d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,7 +25,7 @@ Build the CI docker image: image: docker:latest variables: DOCKER_TLS_CERTDIR: "" - DOCKER_HOST: tcp://docker:2375/ + DOCKER_HOST: tcp://localhost:2375/ services: - name: docker:18.09.7-dind command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] -- GitLab From a65dfa9c45362723737b6458d647fa51cb503d61 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:32:14 +0200 Subject: [PATCH 019/154] WIP: use godzilla runner --- .gitlab-ci.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 98d4dc8d..b6a83be8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -24,11 +24,16 @@ Build the CI docker image: tags: [godzilla] image: docker:latest variables: - DOCKER_TLS_CERTDIR: "" - DOCKER_HOST: tcp://localhost:2375/ + DOCKER_HOST: tcp://docker:2375/ + DOCKER_DRIVER: overlay2 + # See https://github.com/docker-library/docker/pull/166 + DOCKER_TLS_CERTDIR: " + services: - - name: docker:18.09.7-dind - command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] + - name: docker:dind + entrypoint: ["env", "-u", "DOCKER_HOST"] + command: ["dockerd-entrypoint.sh"] + before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 41bc4fbbba393da2cee25a43110feae1dfef4e0c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:33:03 +0200 Subject: [PATCH 020/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b6a83be8..bdacb6b1 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -27,8 +27,7 @@ Build the CI docker image: DOCKER_HOST: tcp://docker:2375/ DOCKER_DRIVER: overlay2 # See https://github.com/docker-library/docker/pull/166 - DOCKER_TLS_CERTDIR: " - + DOCKER_TLS_CERTDIR: "" services: - name: docker:dind entrypoint: ["env", "-u", "DOCKER_HOST"] -- GitLab From 67aa195945ee2734b2f823dc44c2702380b7036e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:49:04 +0200 Subject: [PATCH 021/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index bdacb6b1..25355c11 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,7 +22,7 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker:latest + image: docker:17.06.0-ce variables: DOCKER_HOST: tcp://docker:2375/ DOCKER_DRIVER: overlay2 -- GitLab From 087f404b7528f925f37c9da98577443029976fa2 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 15:52:37 +0200 Subject: [PATCH 022/154] WIP: use godzilla runner --- .gitlab-ci.yml | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 25355c11..11dd9b0e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,17 +22,13 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker:17.06.0-ce + image: docker/compose:1.29.2 variables: - DOCKER_HOST: tcp://docker:2375/ - DOCKER_DRIVER: overlay2 - # See https://github.com/docker-library/docker/pull/166 DOCKER_TLS_CERTDIR: "" + DOCKER_HOST: tcp://docker:2375 services: - name: docker:dind - entrypoint: ["env", "-u", "DOCKER_HOST"] - command: ["dockerd-entrypoint.sh"] - + command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 8d67b260e5ac4e9496f9bdca3cf8c9d38a1f4a57 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 16:01:20 +0200 Subject: [PATCH 023/154] WIP: use godzilla runner --- .gitlab-ci.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 11dd9b0e..ded529b7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -23,12 +23,16 @@ Build the CI docker image: allow_failure: false tags: [godzilla] image: docker/compose:1.29.2 - variables: - DOCKER_TLS_CERTDIR: "" - DOCKER_HOST: tcp://docker:2375 services: - name: docker:dind - command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] + entrypoint: ["env", "-u", "DOCKER_HOST"] + command: ["dockerd-entrypoint.sh"] + variables: + DOCKER_HOST: tcp://docker:2375/ + DOCKER_DRIVER: overlay2 + # See https://github.com/docker-library/docker/pull/166 + DOCKER_TLS_CERTDIR: "" + before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 73919db73f0eb3220a2dafbda6b60c0dfaf6d668 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 16:07:24 +0200 Subject: [PATCH 024/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ded529b7..c11aabec 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -37,6 +37,7 @@ Build the CI docker image: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab # https://docs.gitlab.com/ce/ci/variables/predefined_variables.html + - docker info - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - docker info -- GitLab From c220de6727b37688d7a341f656c1428d30d75430 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 16:16:37 +0200 Subject: [PATCH 025/154] WIP: use godzilla runner --- .gitlab-ci.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c11aabec..a3c32e6a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,14 +25,6 @@ Build the CI docker image: image: docker/compose:1.29.2 services: - name: docker:dind - entrypoint: ["env", "-u", "DOCKER_HOST"] - command: ["dockerd-entrypoint.sh"] - variables: - DOCKER_HOST: tcp://docker:2375/ - DOCKER_DRIVER: overlay2 - # See https://github.com/docker-library/docker/pull/166 - DOCKER_TLS_CERTDIR: "" - before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 872022c5cbd811074cecfe170690910be12c76ed Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 16:28:09 +0200 Subject: [PATCH 026/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a3c32e6a..042b1bc5 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,9 +22,8 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker/compose:1.29.2 services: - - name: docker:dind + - name: docker:17.06.0-ce-dind before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 50b2c2e1ac280077568f002fb89861844ddaf5f6 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 16:28:59 +0200 Subject: [PATCH 027/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 042b1bc5..17b869d7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,6 +22,7 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] + image: docker/compose:1.29.2 services: - name: docker:17.06.0-ce-dind before_script: -- GitLab From 8dd609bfd31d8b5ac9c1cf718135201e0b3cebb4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 17:05:08 +0200 Subject: [PATCH 028/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 17b869d7..8cf0092e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,6 +25,8 @@ Build the CI docker image: image: docker/compose:1.29.2 services: - name: docker:17.06.0-ce-dind + alias: docker + before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 4f051393032649d7ff1830ac85d66edcc7283e73 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 17:06:14 +0200 Subject: [PATCH 029/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8cf0092e..b967a042 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,7 +25,7 @@ Build the CI docker image: image: docker/compose:1.29.2 services: - name: docker:17.06.0-ce-dind - alias: docker + alias: docker before_script: # docker login asks for the password to be passed through stdin for security -- GitLab From fa16bd9b49454f2573619c9ec6643ee3a1f19798 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 20:10:21 +0200 Subject: [PATCH 030/154] WIP: use godzilla runner --- .gitlab-ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b967a042..d12a4ea7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -45,9 +45,9 @@ Build the CI docker image: --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --build-arg BASE_IMAGE="ubuntu:20.04" - --build-arg BZL_CONFIGS="" - --build-arg KEEP_SRC_OTB=true + --build-arg "BASE_IMAGE=ubuntu:20.04" + --build-arg "BZL_CONFIGS=" + --build-arg "KEEP_SRC_OTB=true" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 0105b69365bd4c8f158997343e143c7919afec3b Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 20:54:22 +0200 Subject: [PATCH 031/154] WIP: use godzilla runner --- .gitlab-ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d12a4ea7..f60f4cf9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -45,9 +45,9 @@ Build the CI docker image: --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --build-arg "BASE_IMAGE=ubuntu:20.04" - --build-arg "BZL_CONFIGS=" - --build-arg "KEEP_SRC_OTB=true" + --build-arg BASE_IMAGE=ubuntu:20.04 + --build-arg BZL_CONFIGS= + --build-arg KEEP_SRC_OTB=true . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From c5b5bc1889e7b2fd0ca055c40243f2d27982442c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 20:57:28 +0200 Subject: [PATCH 032/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index f60f4cf9..2cce1c6d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,6 +35,7 @@ Build the CI docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - docker info + - ls $PWD - > docker build --pull -- GitLab From af5addeccb8118623f6fc9ac7a52693178d02ae5 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:10:09 +0200 Subject: [PATCH 033/154] WIP: use godzilla runner --- .gitlab-ci.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2cce1c6d..380420ab 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -46,9 +46,7 @@ Build the CI docker image: --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --build-arg BASE_IMAGE=ubuntu:20.04 - --build-arg BZL_CONFIGS= - --build-arg KEEP_SRC_OTB=true + --build-arg "BASE_IMAGE=ubuntu:20.04" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 4db77eb700eb4b910c6ad40de71343ff0135a45d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:10:56 +0200 Subject: [PATCH 034/154] WIP: use godzilla runner --- .gitlab-ci.yml | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 380420ab..b4e2cb66 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -38,15 +38,9 @@ Build the CI docker image: - ls $PWD - > docker build - --pull - --cache-from $CI_REGISTRY_IMAGE:latest - --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" - --label "org.opencontainers.image.url=$CI_PROJECT_URL" - --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" - --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" - --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" - --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg "BASE_IMAGE=ubuntu:20.04" + --build-arg BZL_CONFIGS= + --build-arg KEEP_SRC_OTB=true . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From f7035e02acc8fdd9dc89c58cc01988d5be733302 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:12:22 +0200 Subject: [PATCH 035/154] WIP: use godzilla runner --- .gitlab-ci.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b4e2cb66..331d96e9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -38,7 +38,15 @@ Build the CI docker image: - ls $PWD - > docker build - --build-arg "BASE_IMAGE=ubuntu:20.04" + --pull + --cache-from $CI_REGISTRY_IMAGE:latest + --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" + --label "org.opencontainers.image.url=$CI_PROJECT_URL" + --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" + --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" + --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" + --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + --build-arg "BASE_IMAGE=gitlab-registry.irstea.fr/remi.cresson/otbtf/otbtf3.0:cpu-basic-dev" --build-arg BZL_CONFIGS= --build-arg KEEP_SRC_OTB=true . -- GitLab From 1cd2ee72036f9b5320ee9416d0950bb671c8bac5 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:14:30 +0200 Subject: [PATCH 036/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 331d96e9..98b2a596 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -47,8 +47,6 @@ Build the CI docker image: --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg "BASE_IMAGE=gitlab-registry.irstea.fr/remi.cresson/otbtf/otbtf3.0:cpu-basic-dev" - --build-arg BZL_CONFIGS= - --build-arg KEEP_SRC_OTB=true . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From d563001a8f64cf2632b501b9a3fbdca3ce743c3c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:16:21 +0200 Subject: [PATCH 037/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 98b2a596..0814e42c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,6 +36,8 @@ Build the CI docker image: script: - docker info - ls $PWD + - docker run ubuntu:20.04 bash -c "echo 'Hello'" + - docker build . - > docker build --pull -- GitLab From daf4fd1c0eb4e2c8217f365c3d7f0a1b8b0d2a15 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:19:16 +0200 Subject: [PATCH 038/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 - Dockerfile | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0814e42c..0fe610f9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -37,7 +37,6 @@ Build the CI docker image: - docker info - ls $PWD - docker run ubuntu:20.04 bash -c "echo 'Hello'" - - docker build . - > docker build --pull diff --git a/Dockerfile b/Dockerfile index 28ee7875..0c0d401a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM $BASE_IMG AS otbtf-base +FROM ${BASE_IMG} AS otbtf-base WORKDIR /tmp ### System packages -- GitLab From e88d916b420a2617efd626399d4dd24513253623 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:23:43 +0200 Subject: [PATCH 039/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 +-- Dockerfile | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0fe610f9..53307fb7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,6 @@ Build the CI docker image: script: - docker info - ls $PWD - - docker run ubuntu:20.04 bash -c "echo 'Hello'" - > docker build --pull @@ -47,7 +46,7 @@ Build the CI docker image: --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --build-arg "BASE_IMAGE=gitlab-registry.irstea.fr/remi.cresson/otbtf/otbtf3.0:cpu-basic-dev" + --build-arg BASE_IMAGE="ubuntu:20.04" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test diff --git a/Dockerfile b/Dockerfile index 0c0d401a..fdb48b1d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,8 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM ${BASE_IMG} AS otbtf-base +RUN echo $BASE_IMG +FROM $BASE_IMG AS otbtf-base WORKDIR /tmp ### System packages -- GitLab From f0533d241f9936fc6da9adb17b3eb8dd18fa1a2a Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:32:42 +0200 Subject: [PATCH 040/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + Dockerfile | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 53307fb7..a8aea358 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,6 +36,7 @@ Build the CI docker image: script: - docker info - ls $PWD + - docker build --build-arg BASE_IMAGE=ubuntu:20.04 . - > docker build --pull diff --git a/Dockerfile b/Dockerfile index fdb48b1d..28ee7875 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,6 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -RUN echo $BASE_IMG FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -- GitLab From 4d66c9648a57a1f148219829ebd6db8ab43e7478 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:54:06 +0200 Subject: [PATCH 041/154] WIP: use godzilla runner --- .gitlab-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a8aea358..84b47b59 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,9 +36,9 @@ Build the CI docker image: script: - docker info - ls $PWD - - docker build --build-arg BASE_IMAGE=ubuntu:20.04 . + - echo docker build --build-arg BASE_IMAGE=ubuntu:20.04 . - > - docker build + echo docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" -- GitLab From d5d717b516e130dca409a8a60cdacc3242572013 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 21:57:10 +0200 Subject: [PATCH 042/154] WIP: use godzilla runner --- .gitlab-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 84b47b59..3833e72c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,9 +36,9 @@ Build the CI docker image: script: - docker info - ls $PWD - - echo docker build --build-arg BASE_IMAGE=ubuntu:20.04 . + - docker build --build-arg BASE_IMAGE=abbcdef . - > - echo docker build + docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" -- GitLab From d610853dd5691037d5d3e2a95ae6308ba4b397a6 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 22:11:44 +0200 Subject: [PATCH 043/154] WIP: use godzilla runner --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 28ee7875..489ef7cc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,8 +4,9 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM $BASE_IMG AS otbtf-base +FROM ubuntu:20.04 AS otbtf-base WORKDIR /tmp +RUN echo $BASE_IMG ### System packages COPY tools/docker/build-deps-*.txt ./ -- GitLab From 67ca132f9c03d9c75384e6027a2931175cfbac23 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 22:17:43 +0200 Subject: [PATCH 044/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- Dockerfile | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3833e72c..3b0deb84 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,7 @@ Build the CI docker image: script: - docker info - ls $PWD - - docker build --build-arg BASE_IMAGE=abbcdef . + - docker build --build-arg "BASE_IMAGE=abbcdef" . - > docker build --pull diff --git a/Dockerfile b/Dockerfile index 489ef7cc..28ee7875 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,9 +4,8 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM ubuntu:20.04 AS otbtf-base +FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -RUN echo $BASE_IMG ### System packages COPY tools/docker/build-deps-*.txt ./ -- GitLab From 840b154c236274f64a6a323b32c35ebf40bf0c32 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 22:28:35 +0200 Subject: [PATCH 045/154] WIP: use godzilla runner --- Dockerfile | 3 --- 1 file changed, 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 28ee7875..c17f0d6d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,6 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory ARG BASE_IMG - -# ---------------------------------------------------------------------------- -# Init base stage - will be cloned as intermediate build env FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -- GitLab From 9a89a612e35f4c5fe71d291522e0a92610440f47 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 22:32:29 +0200 Subject: [PATCH 046/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- Dockerfile | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3b0deb84..42483646 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,7 @@ Build the CI docker image: script: - docker info - ls $PWD - - docker build --build-arg "BASE_IMAGE=abbcdef" . + - docker build --build-arg "BASE_IMAGE=ubuntu" . - > docker build --pull diff --git a/Dockerfile b/Dockerfile index c17f0d6d..28ee7875 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,9 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory ARG BASE_IMG + +# ---------------------------------------------------------------------------- +# Init base stage - will be cloned as intermediate build env FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -- GitLab From 47901c22629473302358c15b75f61ad766160d4e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 22:51:28 +0200 Subject: [PATCH 047/154] WIP: use godzilla runner --- Dockerfile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 28ee7875..55ad17a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,8 +4,11 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM $BASE_IMG AS otbtf-base +FROM ubuntu:20.04 AS otbtf-base + +#FROM $BASE_IMG AS otbtf-base WORKDIR /tmp +RUN echo $BASE_IMG > base_img ### System packages COPY tools/docker/build-deps-*.txt ./ -- GitLab From 6f13c4688d74e2aa8cfb41ed0b5bcc6d6fa640f1 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 23:02:28 +0200 Subject: [PATCH 048/154] WIP: use godzilla runner --- .gitlab-ci.yml | 4 +++- Dockerfile | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 42483646..0bdbeae1 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -47,7 +47,9 @@ Build the CI docker image: --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --build-arg BASE_IMAGE="ubuntu:20.04" + --build-arg OTBTESTS=true + --build-arg KEEP_SRC_OTB=true + --build-arg BZL_CONFIGS="" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test diff --git a/Dockerfile b/Dockerfile index 55ad17a4..01faf82c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,11 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory -ARG BASE_IMG +ARG BASE_IMG=ubuntu:20.04 # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM ubuntu:20.04 AS otbtf-base - -#FROM $BASE_IMG AS otbtf-base +FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -RUN echo $BASE_IMG > base_img ### System packages COPY tools/docker/build-deps-*.txt ./ @@ -88,6 +85,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi ### OTB ARG GUI=false ARG OTB=7.4.0 +ARG OTBTESTS=false RUN mkdir /src/otb WORKDIR /src/otb @@ -105,6 +103,8 @@ RUN apt-get update -y \ sed -i -r "s/-DOTB_USE_(QT|OPENGL|GL[UFE][WT])=OFF/-DOTB_USE_\1=ON/" ../build-flags-otb.txt; fi \ # Possible ENH: superbuild-all-dependencies switch, with separated build-deps-minimal.txt and build-deps-otbcli.txt) #&& if $OTB_SUPERBUILD_ALL; then sed -i -r "s/-DUSE_SYSTEM_([A-Z0-9]*)=ON/-DUSE_SYSTEM_\1=OFF/ " ../build-flags-otb.txt; fi \ + && if $OTBTESTS; then \ + echo "-DBUILD_TESTING=ON" >> "../build-flags-otb.txt" \ && OTB_FLAGS=$(cat "../build-flags-otb.txt") \ && cmake ../otb/SuperBuild -DCMAKE_INSTALL_PREFIX=/opt/otbtf $OTB_FLAGS \ && make -j $(python -c "import os; print(round( os.cpu_count() * $CPU_RATIO ))") -- GitLab From b001e3a3dc096c7f6e330dca5567ead0ca1625b4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 23:03:10 +0200 Subject: [PATCH 049/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0bdbeae1..eb55fe12 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,8 +35,6 @@ Build the CI docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - docker info - - ls $PWD - - docker build --build-arg "BASE_IMAGE=ubuntu" . - > docker build --pull -- GitLab From 29aee91bef11aa592f522083d4c95b0d1fd9982e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:13:11 +0200 Subject: [PATCH 050/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index eb55fe12..3f3576d5 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,7 +22,7 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker/compose:1.29.2 + image: docker/compose:stable services: - name: docker:17.06.0-ce-dind alias: docker diff --git a/Dockerfile b/Dockerfile index 01faf82c..d701a81a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory -ARG BASE_IMG=ubuntu:20.04 +ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -- GitLab From 5230fe25519facb6d7312c7fad077fbe9e9764a4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:20:14 +0200 Subject: [PATCH 051/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3f3576d5..9189d042 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,7 +22,7 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker/compose:stable + image: docker/compose:18.09.06 services: - name: docker:17.06.0-ce-dind alias: docker -- GitLab From d382e440aa02ef022bfe08b59a65fe003c0e383d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:21:43 +0200 Subject: [PATCH 052/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 9189d042..1ac4d143 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,7 +22,7 @@ Build the CI docker image: stage: Build allow_failure: false tags: [godzilla] - image: docker/compose:18.09.06 + image: docker/compose:latest services: - name: docker:17.06.0-ce-dind alias: docker -- GitLab From a9846a4ae1d59ae7d4f95d316caf6d078c911b67 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:26:15 +0200 Subject: [PATCH 053/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1ac4d143..1af7ec9f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,7 +25,7 @@ Build the CI docker image: image: docker/compose:latest services: - name: docker:17.06.0-ce-dind - alias: docker +# alias: docker before_script: # docker login asks for the password to be passed through stdin for security -- GitLab From c196a472f3f1a8008c00b55db36aa80fc84479ea Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:29:34 +0200 Subject: [PATCH 054/154] WIP: use godzilla runner --- .gitlab-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1af7ec9f..8c9f1031 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -18,13 +18,13 @@ stages: - Test - Applications Test -Build the CI docker image: +cpu-basic-test image: stage: Build allow_failure: false tags: [godzilla] image: docker/compose:latest services: - - name: docker:17.06.0-ce-dind + - name: docker:dind # alias: docker before_script: -- GitLab From d974d071ab903ff81a7295c21fa9a92c72dcc14c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:33:46 +0200 Subject: [PATCH 055/154] WIP: use godzilla runner --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index d701a81a..fb31bc52 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory -ARG BASE_IMG +ARG BASE_IMG=tototo123 # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -- GitLab From 6bc245f83ce68d400188a94d1b5adf42b94de5aa Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:36:57 +0200 Subject: [PATCH 056/154] WIP: use godzilla runner --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index fb31bc52..e0c90f86 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,9 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar -## Mandatory -ARG BASE_IMG=tototo123 # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env +## Mandatory +ARG BASE_IMG=tototo123 FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -- GitLab From d8d74a3da6f69691f8de1fa021919a938ec5ff2b Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:38:33 +0200 Subject: [PATCH 057/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8c9f1031..8aaabb81 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -48,6 +48,7 @@ cpu-basic-test image: --build-arg OTBTESTS=true --build-arg KEEP_SRC_OTB=true --build-arg BZL_CONFIGS="" + --build-arg BASE_IMAGE="ubuntu:20.04" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 01c3e2d59695ee720281ddbbcec5f4a21a219a1c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:42:15 +0200 Subject: [PATCH 058/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8aaabb81..1793c069 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,6 +35,7 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - docker info + - docker build --build-arg BASE_IMAGE="ubuntu:20.04" . - > docker build --pull diff --git a/Dockerfile b/Dockerfile index e0c90f86..d701a81a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,9 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar +## Mandatory +ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -## Mandatory -ARG BASE_IMG=tototo123 FROM $BASE_IMG AS otbtf-base WORKDIR /tmp -- GitLab From 1c7927b47154df21960e8aee2379a992758cb72e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:47:49 +0200 Subject: [PATCH 059/154] WIP: use godzilla runner --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index d701a81a..6cc56ff1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM $BASE_IMG AS otbtf-base +FROM ${BASE_IMG} AS otbtf-base WORKDIR /tmp ### System packages -- GitLab From 066ceeb6ce163ed4d5f925ac391a820b6bd2ceeb Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:50:55 +0200 Subject: [PATCH 060/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1793c069..7e52c090 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,6 +35,8 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - docker info + - docker version + - docker-compose version - docker build --build-arg BASE_IMAGE="ubuntu:20.04" . - > docker build -- GitLab From 7b4b3a8096c3d5a5aebed416c4bbe815da6bff9b Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 09:53:57 +0200 Subject: [PATCH 061/154] WIP: use godzilla runner --- .gitlab-ci.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7e52c090..ed44e51c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -23,20 +23,18 @@ cpu-basic-test image: allow_failure: false tags: [godzilla] image: docker/compose:latest - services: - - name: docker:dind -# alias: docker +# services: +# - name: docker:dind before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab # https://docs.gitlab.com/ce/ci/variables/predefined_variables.html - - docker info - - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY - script: - docker info - docker version - docker-compose version + - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY + script: - docker build --build-arg BASE_IMAGE="ubuntu:20.04" . - > docker build -- GitLab From fdee72e249bb0a35551bd1517be6799420167f8c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:03:45 +0200 Subject: [PATCH 062/154] WIP: use godzilla runner --- .gitlab-ci.yml | 5 ++--- Dockerfile | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ed44e51c..c164e835 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -23,9 +23,8 @@ cpu-basic-test image: allow_failure: false tags: [godzilla] image: docker/compose:latest -# services: -# - name: docker:dind - + services: + - name: docker:dind before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab diff --git a/Dockerfile b/Dockerfile index 6cc56ff1..62b0d3cf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM ${BASE_IMG} AS otbtf-base +FROM BASE_IMG AS otbtf-base WORKDIR /tmp ### System packages -- GitLab From a3a8cad9079b004027502d46f54efd4df8078f13 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:07:04 +0200 Subject: [PATCH 063/154] WIP: use godzilla runner --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 62b0d3cf..d701a81a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ARG BASE_IMG # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env -FROM BASE_IMG AS otbtf-base +FROM $BASE_IMG AS otbtf-base WORKDIR /tmp ### System packages -- GitLab From 796d4d493f8ff610b7259c7a580ce8938fd152b0 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:09:54 +0200 Subject: [PATCH 064/154] WIP: use godzilla runner --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index d701a81a..69a3d352 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory +ARG BASE_IMG=tototot123 ARG BASE_IMG # ---------------------------------------------------------------------------- -- GitLab From f0550368e25b55544fd2cf62a3eff6fb50419e93 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:14:00 +0200 Subject: [PATCH 065/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c164e835..cdef14c4 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,6 +25,8 @@ cpu-basic-test image: image: docker/compose:latest services: - name: docker:dind + variables: + - BASE_IMAGE: "ubuntu:20.04" before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From ed84d3022097cdd8de4291b61aaaf9ea8b0e08cd Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:14:26 +0200 Subject: [PATCH 066/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index cdef14c4..141fcb3a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,7 +26,7 @@ cpu-basic-test image: services: - name: docker:dind variables: - - BASE_IMAGE: "ubuntu:20.04" + - BASE_IMAGE: ubuntu:20.04 before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From f4e1c18a1d3abd6a02aba54ab0c59b1be16fedb1 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:15:00 +0200 Subject: [PATCH 067/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 141fcb3a..1aeed61a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,7 +26,7 @@ cpu-basic-test image: services: - name: docker:dind variables: - - BASE_IMAGE: ubuntu:20.04 + BASE_IMAGE: ubuntu:20.04 before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab -- GitLab From 0554581c8c06e9c407d23ff2ed0113881e334855 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:16:15 +0200 Subject: [PATCH 068/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1aeed61a..36f2a88d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,7 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - docker build --build-arg BASE_IMAGE="ubuntu:20.04" . + - docker build . - > docker build --pull -- GitLab From bdd3a78b86f51da73d4b41d383ca4a97f44434c2 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:21:27 +0200 Subject: [PATCH 069/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 36f2a88d..73155c4d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,7 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - docker build . + - bash -c "docker build --build-arg BASE_IMAGE=ubuntu:20.04 ." - > docker build --pull -- GitLab From 29eac4b11dfe39c44aa672cea35aae28314325df Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:24:22 +0200 Subject: [PATCH 070/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 73155c4d..36f2a88d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,7 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - bash -c "docker build --build-arg BASE_IMAGE=ubuntu:20.04 ." + - docker build . - > docker build --pull -- GitLab From 8470424d5948e0cc7b4346f2f7cfa8a6bd4f43a8 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:25:41 +0200 Subject: [PATCH 071/154] WIP: use godzilla runner --- Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 69a3d352..08cdfb75 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,6 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory ARG BASE_IMG=tototot123 -ARG BASE_IMG - # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env FROM $BASE_IMG AS otbtf-base -- GitLab From 995159be228ed7824c076ad773a7656aa3681fc1 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:26:51 +0200 Subject: [PATCH 072/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 36f2a88d..c358a859 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,7 +36,7 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - docker build . + - docker build --build-arg BASE_IMAGE=ubuntu:20.04 . - > docker build --pull -- GitLab From e54119e1436f64bad86d605fb5a55ff4e2276ee5 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:32:41 +0200 Subject: [PATCH 073/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c358a859..85bb7ade 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,6 +25,7 @@ cpu-basic-test image: image: docker/compose:latest services: - name: docker:dind + command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] variables: BASE_IMAGE: ubuntu:20.04 before_script: -- GitLab From 85e000cef323997d1c0fc00aadf4107725c1ce1d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:34:27 +0200 Subject: [PATCH 074/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 85bb7ade..ebbbefea 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -37,7 +37,8 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - docker build --build-arg BASE_IMAGE=ubuntu:20.04 . + - echo docker build --build-arg BASE_IMAGE=$BASE_IMAGE . + - docker build --build-arg BASE_IMAGE=$BASE_IMAGE . - > docker build --pull -- GitLab From a785b2d5a87e3240f61b48a2240794e8328b3a78 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:43:35 +0200 Subject: [PATCH 075/154] WIP: use godzilla runner --- .gitlab-ci.yml | 18 +----------------- tools/docker/docker_build_cpu-dev-tests.sh | 13 +++++++++++++ 2 files changed, 14 insertions(+), 17 deletions(-) create mode 100644 tools/docker/docker_build_cpu-dev-tests.sh diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ebbbefea..42e960a9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -37,23 +37,7 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - echo docker build --build-arg BASE_IMAGE=$BASE_IMAGE . - - docker build --build-arg BASE_IMAGE=$BASE_IMAGE . - - > - docker build - --pull - --cache-from $CI_REGISTRY_IMAGE:latest - --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" - --label "org.opencontainers.image.url=$CI_PROJECT_URL" - --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" - --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" - --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" - --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --build-arg OTBTESTS=true - --build-arg KEEP_SRC_OTB=true - --build-arg BZL_CONFIGS="" - --build-arg BASE_IMAGE="ubuntu:20.04" - . + - tools/docker/./docker_build_cpu-dev-tests.sh - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - docker push $CI_REGISTRY_IMAGE:cpu-basic-test diff --git a/tools/docker/docker_build_cpu-dev-tests.sh b/tools/docker/docker_build_cpu-dev-tests.sh new file mode 100644 index 00000000..c4f9ae2f --- /dev/null +++ b/tools/docker/docker_build_cpu-dev-tests.sh @@ -0,0 +1,13 @@ +docker build \ +--pull \ +--cache-from $CI_REGISTRY_IMAGE:latest \ +--label "org.opencontainers.image.title=$CI_PROJECT_TITLE" \ +--label "org.opencontainers.image.url=$CI_PROJECT_URL" \ +--label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" \ +--label "org.opencontainers.image.revision=$CI_COMMIT_SHA" \ +--label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" \ +--tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME \ +--build-arg OTBTESTS=true \ +--build-arg BZL_CONFIGS="" \ +--build-arg BASE_IMAGE="ubuntu:20.04" \ +. \ No newline at end of file -- GitLab From 1008a74e361396ee53ffd318dc4d5a989ce6fa3c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:45:46 +0200 Subject: [PATCH 076/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 42e960a9..7544c918 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -37,7 +37,7 @@ cpu-basic-test image: - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - tools/docker/./docker_build_cpu-dev-tests.sh + - sh tools/docker/docker_build_cpu-dev-tests.sh - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - docker push $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 8d7132776cbf795f23793dedb680af982d548eca Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:48:32 +0200 Subject: [PATCH 077/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 --- tools/docker/docker_build_cpu-dev-tests.sh | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7544c918..7febdc25 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -25,9 +25,6 @@ cpu-basic-test image: image: docker/compose:latest services: - name: docker:dind - command: [dockerd, '-H', 'tcp://0.0.0.0:2375'] - variables: - BASE_IMAGE: ubuntu:20.04 before_script: # docker login asks for the password to be passed through stdin for security # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab diff --git a/tools/docker/docker_build_cpu-dev-tests.sh b/tools/docker/docker_build_cpu-dev-tests.sh index c4f9ae2f..1b82f51b 100644 --- a/tools/docker/docker_build_cpu-dev-tests.sh +++ b/tools/docker/docker_build_cpu-dev-tests.sh @@ -1,4 +1,4 @@ -docker build \ +echo docker build \ --pull \ --cache-from $CI_REGISTRY_IMAGE:latest \ --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" \ -- GitLab From 44b6f567ece3e5dd612d0e8f661eb6c2d9657bf7 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:53:12 +0200 Subject: [PATCH 078/154] WIP: use godzilla runner --- tools/docker/docker_build_cpu-dev-tests.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/docker/docker_build_cpu-dev-tests.sh b/tools/docker/docker_build_cpu-dev-tests.sh index 1b82f51b..1cf22847 100644 --- a/tools/docker/docker_build_cpu-dev-tests.sh +++ b/tools/docker/docker_build_cpu-dev-tests.sh @@ -1,4 +1,5 @@ -echo docker build \ +docker build --help +docker build \ --pull \ --cache-from $CI_REGISTRY_IMAGE:latest \ --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" \ -- GitLab From 7e1d50a5896b48b1c6aa5c40921d59e3a1fd5169 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:55:05 +0200 Subject: [PATCH 079/154] WIP: use godzilla runner --- tools/docker/docker_build_cpu-dev-tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/docker/docker_build_cpu-dev-tests.sh b/tools/docker/docker_build_cpu-dev-tests.sh index 1cf22847..623d076d 100644 --- a/tools/docker/docker_build_cpu-dev-tests.sh +++ b/tools/docker/docker_build_cpu-dev-tests.sh @@ -1,5 +1,6 @@ docker build --help docker build \ +--no-cache \ --pull \ --cache-from $CI_REGISTRY_IMAGE:latest \ --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" \ -- GitLab From b433a51dcc59d82f859fd1a7df83023cfa31d4c1 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 10:56:45 +0200 Subject: [PATCH 080/154] WIP: use godzilla runner --- tools/docker/docker_build_cpu-dev-tests.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tools/docker/docker_build_cpu-dev-tests.sh b/tools/docker/docker_build_cpu-dev-tests.sh index 623d076d..b5501a32 100644 --- a/tools/docker/docker_build_cpu-dev-tests.sh +++ b/tools/docker/docker_build_cpu-dev-tests.sh @@ -1,5 +1,5 @@ docker build --help -docker build \ +docker build . \ --no-cache \ --pull \ --cache-from $CI_REGISTRY_IMAGE:latest \ @@ -11,5 +11,4 @@ docker build \ --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME \ --build-arg OTBTESTS=true \ --build-arg BZL_CONFIGS="" \ ---build-arg BASE_IMAGE="ubuntu:20.04" \ -. \ No newline at end of file +--build-arg BASE_IMAGE="ubuntu:20.04" \ No newline at end of file -- GitLab From c7a50fea020b245b08028ec24bf83c3d7a069b50 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 12:16:29 +0200 Subject: [PATCH 081/154] WIP: use godzilla runner --- .gitlab-ci.yml | 21 ++++++++++++++------- Dockerfile | 6 ++---- tools/docker/docker_build_cpu-dev-tests.sh | 14 -------------- 3 files changed, 16 insertions(+), 25 deletions(-) delete mode 100644 tools/docker/docker_build_cpu-dev-tests.sh diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7febdc25..e74c91a6 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,15 +26,22 @@ cpu-basic-test image: services: - name: docker:dind before_script: - # docker login asks for the password to be passed through stdin for security - # we use $CI_REGISTRY_PASSWORD here which is a special variable provided by GitLab - # https://docs.gitlab.com/ce/ci/variables/predefined_variables.html - - docker info - - docker version - - docker-compose version - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY script: - - sh tools/docker/docker_build_cpu-dev-tests.sh + - > + docker build + --pull + --cache-from $CI_REGISTRY_IMAGE:latest + --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" + --label "org.opencontainers.image.url=$CI_PROJECT_URL" + --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" + --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" + --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" + --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + --build-arg OTBTESTS="true" + --build-arg BZL_CONFIGS="" + --build-arg BASE_IMG="ubuntu:20.04" + . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - docker push $CI_REGISTRY_IMAGE:cpu-basic-test diff --git a/Dockerfile b/Dockerfile index 08cdfb75..28ee7875 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,7 @@ ##### Configurable Dockerfile with multi-stage build - Author: Vincent Delbar ## Mandatory -ARG BASE_IMG=tototot123 +ARG BASE_IMG + # ---------------------------------------------------------------------------- # Init base stage - will be cloned as intermediate build env FROM $BASE_IMG AS otbtf-base @@ -84,7 +85,6 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi ### OTB ARG GUI=false ARG OTB=7.4.0 -ARG OTBTESTS=false RUN mkdir /src/otb WORKDIR /src/otb @@ -102,8 +102,6 @@ RUN apt-get update -y \ sed -i -r "s/-DOTB_USE_(QT|OPENGL|GL[UFE][WT])=OFF/-DOTB_USE_\1=ON/" ../build-flags-otb.txt; fi \ # Possible ENH: superbuild-all-dependencies switch, with separated build-deps-minimal.txt and build-deps-otbcli.txt) #&& if $OTB_SUPERBUILD_ALL; then sed -i -r "s/-DUSE_SYSTEM_([A-Z0-9]*)=ON/-DUSE_SYSTEM_\1=OFF/ " ../build-flags-otb.txt; fi \ - && if $OTBTESTS; then \ - echo "-DBUILD_TESTING=ON" >> "../build-flags-otb.txt" \ && OTB_FLAGS=$(cat "../build-flags-otb.txt") \ && cmake ../otb/SuperBuild -DCMAKE_INSTALL_PREFIX=/opt/otbtf $OTB_FLAGS \ && make -j $(python -c "import os; print(round( os.cpu_count() * $CPU_RATIO ))") diff --git a/tools/docker/docker_build_cpu-dev-tests.sh b/tools/docker/docker_build_cpu-dev-tests.sh deleted file mode 100644 index b5501a32..00000000 --- a/tools/docker/docker_build_cpu-dev-tests.sh +++ /dev/null @@ -1,14 +0,0 @@ -docker build --help -docker build . \ ---no-cache \ ---pull \ ---cache-from $CI_REGISTRY_IMAGE:latest \ ---label "org.opencontainers.image.title=$CI_PROJECT_TITLE" \ ---label "org.opencontainers.image.url=$CI_PROJECT_URL" \ ---label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" \ ---label "org.opencontainers.image.revision=$CI_COMMIT_SHA" \ ---label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" \ ---tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME \ ---build-arg OTBTESTS=true \ ---build-arg BZL_CONFIGS="" \ ---build-arg BASE_IMAGE="ubuntu:20.04" \ No newline at end of file -- GitLab From c34366b1c4de05657299bc80bc513792d3382b53 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 13:54:39 +0200 Subject: [PATCH 082/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e74c91a6..393c0bdd 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -27,6 +27,7 @@ cpu-basic-test image: - name: docker:dind before_script: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY + timeout: 10 hours script: - > docker build -- GitLab From d223155ef2b4b616fe9d4eaddc2d1d760f61fdba Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 14:25:40 +0200 Subject: [PATCH 083/154] FIX: list indices must be integers --- python/otbtf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/otbtf.py b/python/otbtf.py index cbb96e55..83cb4fe5 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -225,7 +225,7 @@ class PatchesImagesReader(PatchesReaderBase): # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds[src_key]], axis=0) for + self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds], axis=0) for src_key, src_ds in self.gdal_ds.items()} def _get_ds_and_offset_from_index(self, index): -- GitLab From 1a2a42d0eb5ff61ff7266d702d21ef9578b27d4e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 14:27:34 +0200 Subject: [PATCH 084/154] FIX: static function --- python/otbtf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/otbtf.py b/python/otbtf.py index 83cb4fe5..75d188ae 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -581,7 +581,7 @@ class TFRecords: self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None @staticmethod - def _bytes_feature(self, value): + def _bytes_feature(value): """ Convert a value to a type compatible with tf.train.Example. :param value: value -- GitLab From c3717f3b86991d160af74ee09b7bf55a05371efb Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Wed, 20 Apr 2022 12:08:33 +0200 Subject: [PATCH 085/154] (Cherrypick from 14) generalize cropping target to a preprocessing function --- python/otbtf.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 75d188ae..c180a7c4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -672,11 +672,15 @@ class TFRecords: @staticmethod def parse_tfrecord(example, features_types, target_keys): + def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): """ Parse example object to sample dict. :param example: Example object to parse :param features_types: List of types for each feature :param target_keys: list of keys of the targets + :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns + a tuple (input_preprocessed, target_preprocessed) + :param kwargs: some keywords arguments for preprocessing_fn """ read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} example_parsed = tf.io.parse_single_example(example, read_features) @@ -688,9 +692,13 @@ class TFRecords: input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + if preprocessing_fn: + input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs) + return input_parsed, target_parsed - def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None, + preprocessing_fn=None, **kwargs): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. :param batch_size: Size of tensorflow batch @@ -702,12 +710,16 @@ class TFRecords: False is advisable when evaluating metrics so that all samples are used :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size elements are shuffled using uniform random. + :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns + a tuple (input_preprocessed, target_preprocessed) + :param kwargs: some keywords arguments for preprocessing_fn """ options = tf.data.Options() if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, + preprocessing_fn=preprocessing_fn, **kwargs) # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? -- GitLab From 02ec7da70e69aa754f3fd88cc316ef951642d9d4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 15:46:58 +0200 Subject: [PATCH 086/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 393c0bdd..1f47299e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -40,6 +40,7 @@ cpu-basic-test image: --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg OTBTESTS="true" + --build-arg KEEP_OTB_SRC="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" . -- GitLab From ce9c972ff2553102f38ee909de486cf6185bcad4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 15:49:32 +0200 Subject: [PATCH 087/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1f47299e..ea2ed507 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -32,7 +32,7 @@ cpu-basic-test image: - > docker build --pull - --cache-from $CI_REGISTRY_IMAGE:latest + --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" --label "org.opencontainers.image.url=$CI_PROJECT_URL" --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" -- GitLab From 161e33e8cdc9f7024cb5c081f0d0561fdc300b30 Mon Sep 17 00:00:00 2001 From: Vincent Delbar <vincent.delbar@latelescop.fr> Date: Thu, 21 Apr 2022 15:55:49 +0200 Subject: [PATCH 088/154] FIX: indented bloc --- python/otbtf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/otbtf.py b/python/otbtf.py index c180a7c4..b28a1cc4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -671,7 +671,6 @@ class TFRecords: self.save(output_shapes, self.output_shape_file) @staticmethod - def parse_tfrecord(example, features_types, target_keys): def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): """ Parse example object to sample dict. -- GitLab From 3e9b5c3cfced4ca27bab3b32478b2d0bdfd3ded2 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 16:16:09 +0200 Subject: [PATCH 089/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ea2ed507..9ceb3937 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -29,6 +29,7 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: + - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test - > docker build --pull -- GitLab From 6b13b3b544f44abcd0460b03da1d5be03c5b42d5 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 16:35:00 +0200 Subject: [PATCH 090/154] WIP: use godzilla runner --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d1993d5b..343c4863 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ # OTBTF: Orfeo ToolBox meets TensorFlow [](https://opensource.org/licenses/Apache-2.0) +[](https://gitlab.irstea.fr/remi.cresson/otbtf/-/commits/develop) This remote module of the [Orfeo ToolBox](https://www.orfeo-toolbox.org) provides a generic, multi purpose deep learning framework, targeting remote sensing images processing. It contains a set of new process objects that internally invoke [Tensorflow](https://www.tensorflow.org/), and a bunch of user-oriented applications to perform deep learning with real-world remote sensing images. -- GitLab From 929dae88414026ac9dc760cb01d23e7d99a4a7ce Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 21 Apr 2022 16:55:58 +0200 Subject: [PATCH 091/154] ENH: generate samples of same type as initial raster --- python/otbtf.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index b28a1cc4..a1cf9bd4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -58,8 +58,11 @@ def read_as_np_arr(gdal_ds, as_patches=True): False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - buffer = gdal_ds.ReadAsArray() + gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', + 10: 'complex64', 11: 'complex128'} + gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize + buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: @@ -68,7 +71,7 @@ def read_as_np_arr(gdal_ds, as_patches=True): else: n_elems = int(gdal_ds.RasterYSize / size_x) size_y = size_x - return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) + return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)) # -------------------------------------------------- Buffer class ------------------------------------------------------ @@ -244,8 +247,11 @@ class PatchesImagesReader(PatchesReaderBase): @staticmethod def _read_extract_as_np_arr(gdal_ds, offset): + gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', + 10: 'complex64', 11: 'complex128'} assert gdal_ds is not None psz = gdal_ds.RasterXSize + gdal_type = gdal_ds.GetRasterBand(1).DataType yoff = int(offset * psz) assert yoff + psz <= gdal_ds.RasterYSize buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) @@ -254,7 +260,7 @@ class PatchesImagesReader(PatchesReaderBase): else: # single-band raster buffer = np.expand_dims(buffer, axis=2) - return np.float32(buffer) + return buffer.astype(gdal_to_np_types[gdal_type]) def get_sample(self, index): """ -- GitLab From fd264b44acbab66866a365e2c29b7c09e28c7a08 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 21 Apr 2022 17:18:04 +0200 Subject: [PATCH 092/154] REFAC: rename python folder + split otbtf.py in several files --- {python => otbtf}/__init__.py | 0 {python => otbtf}/ckpt2savedmodel.py | 0 python/otbtf.py => otbtf/dataset.py | 249 +----------------- .../create_savedmodel_ienco-m3_patchbased.py | 0 .../create_savedmodel_maggiori17_fullyconv.py | 0 .../create_savedmodel_pxs_fcn.py | 0 .../create_savedmodel_simple_cnn.py | 0 .../create_savedmodel_simple_fcn.py | 0 .../examples/tensorflow_v2x/l2_norm.py | 0 .../examples/tensorflow_v2x/scalar_product.py | 0 otbtf/tfrecords.py | 208 +++++++++++++++ {python => otbtf}/tricks.py | 0 otbtf/utils.py | 35 +++ 13 files changed, 245 insertions(+), 247 deletions(-) rename {python => otbtf}/__init__.py (100%) rename {python => otbtf}/ckpt2savedmodel.py (100%) rename python/otbtf.py => otbtf/dataset.py (64%) rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py (100%) rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py (100%) rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py (100%) rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py (100%) rename {python => otbtf}/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py (100%) rename {python => otbtf}/examples/tensorflow_v2x/l2_norm.py (100%) rename {python => otbtf}/examples/tensorflow_v2x/scalar_product.py (100%) create mode 100644 otbtf/tfrecords.py rename {python => otbtf}/tricks.py (100%) create mode 100644 otbtf/utils.py diff --git a/python/__init__.py b/otbtf/__init__.py similarity index 100% rename from python/__init__.py rename to otbtf/__init__.py diff --git a/python/ckpt2savedmodel.py b/otbtf/ckpt2savedmodel.py similarity index 100% rename from python/ckpt2savedmodel.py rename to otbtf/ckpt2savedmodel.py diff --git a/python/otbtf.py b/otbtf/dataset.py similarity index 64% rename from python/otbtf.py rename to otbtf/dataset.py index b28a1cc4..4b0f945d 100644 --- a/python/otbtf.py +++ b/otbtf/dataset.py @@ -20,60 +20,19 @@ """ Contains stuff to help working with TensorFlow and geospatial data in the OTBTF framework. """ -import glob -import json -import os import threading import multiprocessing import time import logging from abc import ABC, abstractmethod -from functools import partial import numpy as np import tensorflow as tf -from osgeo import gdal -from tqdm import tqdm - - -# ----------------------------------------------------- Helpers -------------------------------------------------------- - - -def gdal_open(filename): - """ - Open a GDAL raster - :param filename: raster file - :return: a GDAL dataset instance - """ - gdal_ds = gdal.Open(filename) - if gdal_ds is None: - raise Exception("Unable to open file {}".format(filename)) - return gdal_ds - - -def read_as_np_arr(gdal_ds, as_patches=True): - """ - Read a GDAL raster as numpy array - :param gdal_ds: a GDAL dataset instance - :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If - False, the shape is (1, psz_y, psz_x, nb_channels) - :return: Numpy array of dim 4 - """ - buffer = gdal_ds.ReadAsArray() - size_x = gdal_ds.RasterXSize - if len(buffer.shape) == 3: - buffer = np.transpose(buffer, axes=(1, 2, 0)) - if not as_patches: - n_elems = 1 - size_y = gdal_ds.RasterYSize - else: - n_elems = int(gdal_ds.RasterYSize / size_x) - size_y = size_x - return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) +from otbtf.utils import read_as_np_arr, gdal_open +from otbtf.tfrecords import TFRecords # -------------------------------------------------- Buffer class ------------------------------------------------------ - class Buffer: """ Used to store and access list of objects @@ -106,7 +65,6 @@ class Buffer: # ---------------------------------------------- PatchesReaderBase class ----------------------------------------------- - class PatchesReaderBase(ABC): """ Base class for patches delivery @@ -151,7 +109,6 @@ class PatchesReaderBase(ABC): # --------------------------------------------- PatchesImagesReader class ---------------------------------------------- - class PatchesImagesReader(PatchesReaderBase): """ This class provides a read access to a set of patches images. @@ -327,7 +284,6 @@ class PatchesImagesReader(PatchesReaderBase): # ----------------------------------------------- IteratorBase class --------------------------------------------------- - class IteratorBase(ABC): """ Base class for iterators @@ -340,7 +296,6 @@ class IteratorBase(ABC): # ---------------------------------------------- RandomIterator class -------------------------------------------------- - class RandomIterator(IteratorBase): """ Pick a random number in the [0, handler.size) range. @@ -370,7 +325,6 @@ class RandomIterator(IteratorBase): # ------------------------------------------------- Dataset class ------------------------------------------------------ - class Dataset: """ Handles the "mining" of patches. @@ -532,7 +486,6 @@ class Dataset: # ----------------------------------------- DatasetFromPatchesImages class --------------------------------------------- - class DatasetFromPatchesImages(Dataset): """ Handles the "mining" of a set of patches images. @@ -559,202 +512,4 @@ class DatasetFromPatchesImages(Dataset): super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator) -class TFRecords: - """ - This class allows to convert Dataset objects to TFRecords and to load them in dataset tensorflows format. - """ - - def __init__(self, path): - """ - :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path - """ - if os.path.isdir(path) or not os.path.exists(path): - self.dirpath = path - os.makedirs(self.dirpath, exist_ok=True) - self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") - else: - self.dirpath = os.path.dirname(path) - self.tfrecords_pattern_path = path - self.output_types_file = os.path.join(self.dirpath, "output_types.json") - self.output_shape_file = os.path.join(self.dirpath, "output_shape.json") - self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None - self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None - - @staticmethod - def _bytes_feature(value): - """ - Convert a value to a type compatible with tf.train.Example. - :param value: value - :return a bytes_list from a string / byte. - """ - if isinstance(value, type(tf.constant(0))): - value = value.numpy() # BytesList won't unpack a string from an EagerTensor. - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True): - """ - Convert and save samples from dataset object to tfrecord files. - :param dataset: Dataset object to convert into a set of tfrecords - :param n_samples_per_shard: Number of samples per shard - :param drop_remainder: Whether additional samples should be dropped. Advisable if using multiworkers training. - If True, all TFRecords will have `n_samples_per_shard` samples - """ - logging.info("%s samples", dataset.size) - - nb_shards = (dataset.size // n_samples_per_shard) - if not drop_remainder and dataset.size % n_samples_per_shard > 0: - nb_shards += 1 - - self.convert_dataset_output_shapes(dataset) - - def _convert_data(data): - """ - Convert data - """ - data_converted = {} - - for k, d in data.items(): - data_converted[k] = d.name - return data_converted - - self.save(_convert_data(dataset.output_types), self.output_types_file) - - for i in tqdm(range(nb_shards)): - - if (i + 1) * n_samples_per_shard <= dataset.size: - nb_sample = n_samples_per_shard - else: - nb_sample = dataset.size - i * n_samples_per_shard - - filepath = os.path.join(self.dirpath, f"{i}.records") - with tf.io.TFRecordWriter(filepath) as writer: - for s in range(nb_sample): - sample = dataset.read_one_sample() - serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} - features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in - serialized_sample.items()} - tf_features = tf.train.Features(feature=features) - example = tf.train.Example(features=tf_features) - writer.write(example.SerializeToString()) - - @staticmethod - def save(data, filepath): - """ - Save data to pickle format. - :param data: Data to save json format - :param filepath: Output file name - """ - - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) - - @staticmethod - def load(filepath): - """ - Return data from pickle format. - :param filepath: Input file name - """ - with open(filepath, 'r') as f: - return json.load(f) - - def convert_dataset_output_shapes(self, dataset): - """ - Convert and save numpy shape to tensorflow shape. - :param dataset: Dataset object containing output shapes - """ - output_shapes = {} - - for key in dataset.output_shapes.keys(): - output_shapes[key] = (None,) + dataset.output_shapes[key] - - self.save(output_shapes, self.output_shape_file) - - @staticmethod - def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): - """ - Parse example object to sample dict. - :param example: Example object to parse - :param features_types: List of types for each feature - :param target_keys: list of keys of the targets - :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns - a tuple (input_preprocessed, target_preprocessed) - :param kwargs: some keywords arguments for preprocessing_fn - """ - read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} - example_parsed = tf.io.parse_single_example(example, read_features) - - for key in read_features.keys(): - example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key]) - - # Differentiating inputs and outputs - input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} - target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} - - if preprocessing_fn: - input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs) - - return input_parsed, target_parsed - - def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None, - preprocessing_fn=None, **kwargs): - """ - Read all tfrecord files matching with pattern and convert data to tensorflow dataset. - :param batch_size: Size of tensorflow batch - :param target_keys: Keys of the target, e.g. ['s2_out'] - :param n_workers: number of workers, e.g. 4 if using 4 GPUs - e.g. 12 if using 3 nodes of 4 GPUs - :param drop_remainder: whether the last batch should be dropped in the case it has fewer than - `batch_size` elements. True is advisable when training on multiworkers. - False is advisable when evaluating metrics so that all samples are used - :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size - elements are shuffled using uniform random. - :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns - a tuple (input_preprocessed, target_preprocessed) - :param kwargs: some keywords arguments for preprocessing_fn - """ - options = tf.data.Options() - if shuffle_buffer_size: - options.experimental_deterministic = False # disable order, increase speed - options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, - preprocessing_fn=preprocessing_fn, **kwargs) - - # TODO: to be investigated : - # 1/ num_parallel_reads useful ? I/O bottleneck of not ? - # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ? - # 3/ shuffle or not shuffle ? - matching_files = glob.glob(self.tfrecords_pattern_path) - logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path) - logging.info('Number of matching TFRecords: %s', len(matching_files)) - matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)] # files multiple of workers - nb_matching_files = len(matching_files) - if nb_matching_files == 0: - raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord " - "files is greater or equal than the number of workers!".format(self.tfrecords_pattern_path)) - logging.info('Reducing number of records to : %s', nb_matching_files) - dataset = tf.data.TFRecordDataset(matching_files) # , num_parallel_reads=2) # interleaves reads from xxx files - dataset = dataset.with_options(options) # uses data as soon as it streams in, rather than in its original order - dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE) - if shuffle_buffer_size: - dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) - dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) - dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) - # TODO voir si on met le prefetch avant le batch cf https://keras.io/examples/keras_recipes/tfrecord/ - - return dataset - - def read_one_sample(self, target_keys): - """ - Read one tfrecord file matching with pattern and convert data to tensorflow dataset. - :param target_key: Key of the target, e.g. 's2_out' - """ - matching_files = glob.glob(self.tfrecords_pattern_path) - one_file = matching_files[0] - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) - dataset = tf.data.TFRecordDataset(one_file) - dataset = dataset.map(parse) - dataset = dataset.batch(1) - - sample = iter(dataset).get_next() - return sample diff --git a/python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py similarity index 100% rename from python/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py rename to otbtf/examples/tensorflow_v1x/create_savedmodel_ienco-m3_patchbased.py diff --git a/python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py similarity index 100% rename from python/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py rename to otbtf/examples/tensorflow_v1x/create_savedmodel_maggiori17_fullyconv.py diff --git a/python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py similarity index 100% rename from python/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py rename to otbtf/examples/tensorflow_v1x/create_savedmodel_pxs_fcn.py diff --git a/python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py similarity index 100% rename from python/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py rename to otbtf/examples/tensorflow_v1x/create_savedmodel_simple_cnn.py diff --git a/python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py b/otbtf/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py similarity index 100% rename from python/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py rename to otbtf/examples/tensorflow_v1x/create_savedmodel_simple_fcn.py diff --git a/python/examples/tensorflow_v2x/l2_norm.py b/otbtf/examples/tensorflow_v2x/l2_norm.py similarity index 100% rename from python/examples/tensorflow_v2x/l2_norm.py rename to otbtf/examples/tensorflow_v2x/l2_norm.py diff --git a/python/examples/tensorflow_v2x/scalar_product.py b/otbtf/examples/tensorflow_v2x/scalar_product.py similarity index 100% rename from python/examples/tensorflow_v2x/scalar_product.py rename to otbtf/examples/tensorflow_v2x/scalar_product.py diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py new file mode 100644 index 00000000..17fa51e9 --- /dev/null +++ b/otbtf/tfrecords.py @@ -0,0 +1,208 @@ +import glob +import json +import os +import logging +from functools import partial +import tensorflow as tf +from tqdm import tqdm + + +class TFRecords: + """ + This class allows to convert Dataset objects to TFRecords and to load them in dataset tensorflows format. + """ + + def __init__(self, path): + """ + :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path + """ + if os.path.isdir(path) or not os.path.exists(path): + self.dirpath = path + os.makedirs(self.dirpath, exist_ok=True) + self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") + else: + self.dirpath = os.path.dirname(path) + self.tfrecords_pattern_path = path + self.output_types_file = os.path.join(self.dirpath, "output_types.json") + self.output_shape_file = os.path.join(self.dirpath, "output_shape.json") + self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None + self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None + + @staticmethod + def _bytes_feature(value): + """ + Convert a value to a type compatible with tf.train.Example. + :param value: value + :return a bytes_list from a string / byte. + """ + if isinstance(value, type(tf.constant(0))): + value = value.numpy() # BytesList won't unpack a string from an EagerTensor. + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True): + """ + Convert and save samples from dataset object to tfrecord files. + :param dataset: Dataset object to convert into a set of tfrecords + :param n_samples_per_shard: Number of samples per shard + :param drop_remainder: Whether additional samples should be dropped. Advisable if using multiworkers training. + If True, all TFRecords will have `n_samples_per_shard` samples + """ + logging.info("%s samples", dataset.size) + + nb_shards = (dataset.size // n_samples_per_shard) + if not drop_remainder and dataset.size % n_samples_per_shard > 0: + nb_shards += 1 + + self.convert_dataset_output_shapes(dataset) + + def _convert_data(data): + """ + Convert data + """ + data_converted = {} + + for k, d in data.items(): + data_converted[k] = d.name + + return data_converted + + self.save(_convert_data(dataset.output_types), self.output_types_file) + + for i in tqdm(range(nb_shards)): + + if (i + 1) * n_samples_per_shard <= dataset.size: + nb_sample = n_samples_per_shard + else: + nb_sample = dataset.size - i * n_samples_per_shard + + filepath = os.path.join(self.dirpath, f"{i}.records") + with tf.io.TFRecordWriter(filepath) as writer: + for s in range(nb_sample): + sample = dataset.read_one_sample() + serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} + features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in + serialized_sample.items()} + tf_features = tf.train.Features(feature=features) + example = tf.train.Example(features=tf_features) + writer.write(example.SerializeToString()) + + @staticmethod + def save(data, filepath): + """ + Save data to pickle format. + :param data: Data to save json format + :param filepath: Output file name + """ + + with open(filepath, 'w') as f: + json.dump(data, f, indent=4) + + @staticmethod + def load(filepath): + """ + Return data from pickle format. + :param filepath: Input file name + """ + with open(filepath, 'r') as f: + return json.load(f) + + def convert_dataset_output_shapes(self, dataset): + """ + Convert and save numpy shape to tensorflow shape. + :param dataset: Dataset object containing output shapes + """ + output_shapes = {} + + for key in dataset.output_shapes.keys(): + output_shapes[key] = (None,) + dataset.output_shapes[key] + + self.save(output_shapes, self.output_shape_file) + + @staticmethod + def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): + """ + Parse example object to sample dict. + :param example: Example object to parse + :param features_types: List of types for each feature + :param target_keys: list of keys of the targets + :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns + a tuple (input_preprocessed, target_preprocessed) + :param kwargs: some keywords arguments for preprocessing_fn + """ + read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} + example_parsed = tf.io.parse_single_example(example, read_features) + + for key in read_features.keys(): + example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key]) + + # Differentiating inputs and outputs + input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} + target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + + if preprocessing_fn: + input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs) + + return input_parsed, target_parsed + + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None, + preprocessing_fn=None, **kwargs): + """ + Read all tfrecord files matching with pattern and convert data to tensorflow dataset. + :param batch_size: Size of tensorflow batch + :param target_keys: Keys of the target, e.g. ['s2_out'] + :param n_workers: number of workers, e.g. 4 if using 4 GPUs + e.g. 12 if using 3 nodes of 4 GPUs + :param drop_remainder: whether the last batch should be dropped in the case it has fewer than + `batch_size` elements. True is advisable when training on multiworkers. + False is advisable when evaluating metrics so that all samples are used + :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size + elements are shuffled using uniform random. + :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns + a tuple (input_preprocessed, target_preprocessed) + :param kwargs: some keywords arguments for preprocessing_fn + """ + options = tf.data.Options() + if shuffle_buffer_size: + options.experimental_deterministic = False # disable order, increase speed + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, + preprocessing_fn=preprocessing_fn, **kwargs) + + # TODO: to be investigated : + # 1/ num_parallel_reads useful ? I/O bottleneck of not ? + # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ? + # 3/ shuffle or not shuffle ? + matching_files = glob.glob(self.tfrecords_pattern_path) + logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path) + logging.info('Number of matching TFRecords: %s', len(matching_files)) + matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)] # files multiple of workers + nb_matching_files = len(matching_files) + if nb_matching_files == 0: + raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord " + "files is greater or equal than the number of workers!".format(self.tfrecords_pattern_path)) + logging.info('Reducing number of records to : %s', nb_matching_files) + dataset = tf.data.TFRecordDataset(matching_files) # , num_parallel_reads=2) # interleaves reads from xxx files + dataset = dataset.with_options(options) # uses data as soon as it streams in, rather than in its original order + dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE) + if shuffle_buffer_size: + dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + # TODO voir si on met le prefetch avant le batch cf https://keras.io/examples/keras_recipes/tfrecord/ + + return dataset + + def read_one_sample(self, target_keys): + """ + Read one tfrecord file matching with pattern and convert data to tensorflow dataset. + :param target_key: Key of the target, e.g. 's2_out' + """ + matching_files = glob.glob(self.tfrecords_pattern_path) + one_file = matching_files[0] + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + dataset = tf.data.TFRecordDataset(one_file) + dataset = dataset.map(parse) + dataset = dataset.batch(1) + + sample = iter(dataset).get_next() + return sample \ No newline at end of file diff --git a/python/tricks.py b/otbtf/tricks.py similarity index 100% rename from python/tricks.py rename to otbtf/tricks.py diff --git a/otbtf/utils.py b/otbtf/utils.py new file mode 100644 index 00000000..920b0dc6 --- /dev/null +++ b/otbtf/utils.py @@ -0,0 +1,35 @@ +from osgeo import gdal + +# ----------------------------------------------------- Helpers -------------------------------------------------------- + +def gdal_open(filename): + """ + Open a GDAL raster + :param filename: raster file + :return: a GDAL dataset instance + """ + gdal_ds = gdal.Open(filename) + if gdal_ds is None: + raise Exception("Unable to open file {}".format(filename)) + return gdal_ds + + +def read_as_np_arr(gdal_ds, as_patches=True): + """ + Read a GDAL raster as numpy array + :param gdal_ds: a GDAL dataset instance + :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If + False, the shape is (1, psz_y, psz_x, nb_channels) + :return: Numpy array of dim 4 + """ + buffer = gdal_ds.ReadAsArray() + size_x = gdal_ds.RasterXSize + if len(buffer.shape) == 3: + buffer = np.transpose(buffer, axes=(1, 2, 0)) + if not as_patches: + n_elems = 1 + size_y = gdal_ds.RasterYSize + else: + n_elems = int(gdal_ds.RasterYSize / size_x) + size_y = size_x + return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) \ No newline at end of file -- GitLab From e4f2f7ac31e61693741c03e25712e28053631ff0 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 21 Apr 2022 17:21:38 +0200 Subject: [PATCH 093/154] ENH: add setup.py --- setup.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..3a95ac4a --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +import setuptools + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setuptools.setup( + name="otbtf", + version="0.1", + author="Remi Cresson", + author_email="remi.cresson@inrae.fr", + description="OTBTF: Orfeo ToolBox meets TensorFlow", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://gitlab.irstea.fr/remi.cresson/otbtf", + classifiers=[ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Topic :: Scientific/Engineering :: GIS", + "Topic :: Scientific/Engineering :: Image Processing", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + packages=setuptools.find_packages(), + python_requires=">=3.6", + keywords="remote sensing, otb, orfeotoolbox, orfeo toolbox, tensorflow, tf, deep learning, machine learning", +) \ No newline at end of file -- GitLab From 30c277e03283810f02ad01b9804c1b2205666361 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 18:09:17 +0200 Subject: [PATCH 094/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 9ceb3937..fe55016f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -33,6 +33,7 @@ cpu-basic-test image: - > docker build --pull + --network="host" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" --label "org.opencontainers.image.url=$CI_PROJECT_URL" -- GitLab From 859e9fb47df1b4806c05ce29848aefa7ac706901 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 18:23:11 +0200 Subject: [PATCH 095/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index fe55016f..5e525bed 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -30,6 +30,7 @@ cpu-basic-test image: timeout: 10 hours script: - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test + - wget http://localhost:9090/status - > docker build --pull @@ -45,6 +46,7 @@ cpu-basic-test image: --build-arg KEEP_OTB_SRC="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=http://localhost:9090" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From fb691055653247ae4404c29dabb03523f69c5952 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 18:25:18 +0200 Subject: [PATCH 096/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 5e525bed..749cf50a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -29,8 +29,8 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: + - wget http://127.0.0.1:9090/status - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test - - wget http://localhost:9090/status - > docker build --pull -- GitLab From 14b8456c3c52d12523db683ddb754ca4f838ebb0 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 18:27:32 +0200 Subject: [PATCH 097/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 749cf50a..75a40a10 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -29,7 +29,7 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - wget http://127.0.0.1:9090/status + - wget http://docker:9090/status - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test - > docker build -- GitLab From 2bc2672022b17b0822e023a265c2a9b27856accf Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 18:35:09 +0200 Subject: [PATCH 098/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 75a40a10..2b3e57f2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -29,7 +29,7 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - wget http://docker:9090/status + - wget http://localhost:9090/status - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test - > docker build @@ -46,7 +46,6 @@ cpu-basic-test image: --build-arg KEEP_OTB_SRC="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=http://localhost:9090" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 6504a690084c08243c3791ce8fefa2e6fa6e4561 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 19:57:18 +0200 Subject: [PATCH 099/154] WIP: use godzilla runner --- .gitlab-ci.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2b3e57f2..2d44501f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -30,12 +30,14 @@ cpu-basic-test image: timeout: 10 hours script: - wget http://localhost:9090/status - - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test + - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || : + - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || : - > docker build --pull --network="host" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test + --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" --label "org.opencontainers.image.url=$CI_PROJECT_URL" --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" @@ -48,8 +50,8 @@ cpu-basic-test image: --build-arg BASE_IMG="ubuntu:20.04" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:cpu-basic-test +# - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test +# - docker push $CI_REGISTRY_IMAGE:cpu-basic-test .static_analysis_base: stage: Static Analysis -- GitLab From 5d04d86b39ee87a3d10a917ad76109a808fc147a Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 20:02:10 +0200 Subject: [PATCH 100/154] WIP: use godzilla runner --- .gitlab-ci.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2d44501f..68748f09 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -29,9 +29,8 @@ cpu-basic-test image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - wget http://localhost:9090/status - - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || : - - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || : + - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || + - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > docker build --pull -- GitLab From 55c1375a7f9316c381d5891d8fd5a640322209f7 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 21:28:17 +0200 Subject: [PATCH 101/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 68748f09..81efee65 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -44,7 +44,7 @@ cpu-basic-test image: --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg OTBTESTS="true" - --build-arg KEEP_OTB_SRC="true" + --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" . -- GitLab From dde281eaa79229b30ed9a4cf0d4535358dda44cc Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 21:36:50 +0200 Subject: [PATCH 102/154] WIP: use godzilla runner --- Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Dockerfile b/Dockerfile index 28ee7875..4a4d3ab7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -85,6 +85,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi ### OTB ARG GUI=false ARG OTB=7.4.0 +ARG OTBTESTS=false RUN mkdir /src/otb WORKDIR /src/otb @@ -97,6 +98,8 @@ RUN apt-get update -y \ && git clone --single-branch -b $OTB https://gitlab.orfeo-toolbox.org/orfeotoolbox/otb.git \ && mkdir -p build \ && cd build \ + && if $OTBTESTS; then \ + echo "-DBUILD_TESTING=ON" >> ../build-flags-otb.txt; fi \ # Set GL/Qt build flags && if $GUI; then \ sed -i -r "s/-DOTB_USE_(QT|OPENGL|GL[UFE][WT])=OFF/-DOTB_USE_\1=ON/" ../build-flags-otb.txt; fi \ -- GitLab From 7fb41f987d21b0c3a23795f96806a22222f9a523 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 08:29:49 +0200 Subject: [PATCH 103/154] WIP: use godzilla runner --- .gitlab-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 81efee65..bae051cc 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: $CI_REGISTRY_IMAGE:cpu-basic-test +image: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME variables: OTB_BUILD: /src/otb/build/OTB/build # Local OTB build directory @@ -18,7 +18,7 @@ stages: - Test - Applications Test -cpu-basic-test image: +test docker image: stage: Build allow_failure: false tags: [godzilla] -- GitLab From 1d55b2ea21bd6c96ac1016cfdf542b6b3a6882f0 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 08:39:17 +0200 Subject: [PATCH 104/154] WIP: use godzilla runner --- .gitlab-ci.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index bae051cc..e4973ff3 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,8 +10,7 @@ variables: workflow: rules: - if: $CI_MERGE_REQUEST_ID # Execute jobs in merge request context - - if: $CI_COMMIT_BRANCH == 'develop' # Execute jobs when a new commit is pushed to develop branch - + stages: - Build - Static Analysis @@ -49,8 +48,6 @@ test docker image: --build-arg BASE_IMG="ubuntu:20.04" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -# - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -# - docker push $CI_REGISTRY_IMAGE:cpu-basic-test .static_analysis_base: stage: Static Analysis @@ -126,3 +123,11 @@ sr4rs: - git clone https://github.com/remicres/sr4rs.git - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py + +deploy: + only: + - master + script: + - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 3315b2bc8cf5cdd63c92ce9bf2154eb4bea6e09a Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 08:39:58 +0200 Subject: [PATCH 105/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e4973ff3..eb4f1c04 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -16,6 +16,7 @@ stages: - Static Analysis - Test - Applications Test + - deploy test docker image: stage: Build -- GitLab From f9928f24b4905bb4ccec18250e47e9816183fe21 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 08:45:16 +0200 Subject: [PATCH 106/154] WIP: use godzilla runner --- .gitlab-ci.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index eb4f1c04..5f0cdd41 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -16,13 +16,15 @@ stages: - Static Analysis - Test - Applications Test - - deploy + - Ship test docker image: stage: Build allow_failure: false tags: [godzilla] image: docker/compose:latest + except: + - master services: - name: docker:dind before_script: @@ -126,6 +128,7 @@ sr4rs: - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py deploy: + stage: Ship only: - master script: -- GitLab From fb8ea9dd621bb4ecf93ef35d85deb906edf10a9a Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 08:53:32 +0200 Subject: [PATCH 107/154] WIP: use godzilla runner --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 5f0cdd41..33398d14 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -49,6 +49,7 @@ test docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BZL_OPTIONS="--verbose_failures --output_user_root=/bzl_cache" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From bce03638a995ffbc36bb46e62d41a491093a9c07 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 09:19:11 +0200 Subject: [PATCH 108/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 ++- Dockerfile | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 33398d14..14eb7285 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -49,7 +49,8 @@ test docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BZL_OPTIONS="--verbose_failures --output_user_root=/bzl_cache" + --build-arg BZL_PRE="--output_user_root=/bzl_cache" + --build-arg BZL_OPTIONS="--verbose_failures" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME diff --git a/Dockerfile b/Dockerfile index 4a4d3ab7..ee05b7a0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,6 +51,8 @@ ARG BZL_TARGETS="//tensorflow:libtensorflow_cc.so //tensorflow/tools/pip_package ARG BZL_CONFIGS="--config=nogcp --config=noaws --config=nohdfs --config=opt" # "--compilation_mode opt" is already enabled by default (see tf repo .bazelrc and configure.py) ARG BZL_OPTIONS="--verbose_failures --remote_cache=http://localhost:9090" +# options between "bazel" and "build" directive +ARG BZL_PRE="" # Build ARG ZIP_TF_BIN=false @@ -63,7 +65,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi source ../build-env-tf.sh \ && ./configure \ && export TMP=/tmp/bazel \ - && BZL_CMD="build $BZL_TARGETS $BZL_CONFIGS $BZL_OPTIONS" \ + && BZL_CMD="$BZL_PRE build $BZL_TARGETS $BZL_CONFIGS $BZL_OPTIONS" \ && bazel $BZL_CMD --jobs="HOST_CPUS*$CPU_RATIO" ' \ # Installation - split here if you want to check files ^ #RUN cd tensorflow \ -- GitLab From a7a38bb1073cbfd1ca80411c8a7696529472d003 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 09:29:07 +0200 Subject: [PATCH 109/154] WIP: use godzilla runner --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 14eb7285..39c92082 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -49,7 +49,7 @@ test docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BZL_PRE="--output_user_root=/bzl_cache" + --build-arg BZL_PRE="--output_user_root=/bzl_cache/cache" --build-arg BZL_OPTIONS="--verbose_failures" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From 54e6939e1ec2d1e26397c4c1364d9af6a1cb3d1e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 11:18:22 +0200 Subject: [PATCH 110/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 39c92082..6e7e8491 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,6 +31,9 @@ test docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: + - ls / + - ls /bzl_cache/ + - ls /bzl_cache/cache - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > -- GitLab From e5c9d0f1a9865328d89db2402bad9ece0fba599b Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 11:27:16 +0200 Subject: [PATCH 111/154] WIP: use godzilla runner --- .gitlab-ci.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6e7e8491..cef7e838 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,9 +31,8 @@ test docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - ls / - - ls /bzl_cache/ - - ls /bzl_cache/cache + - ls -ll / + - ls -ll /bzl_cache/ - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > -- GitLab From 1e405a044724936d1dfee5ab690a66beb0f0795e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 11:47:14 +0200 Subject: [PATCH 112/154] WIP: use godzilla runner --- .gitlab-ci.yml | 3 +-- Dockerfile | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index cef7e838..798ece96 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -33,6 +33,7 @@ test docker image: script: - ls -ll / - ls -ll /bzl_cache/ + - touch /bzl_cache/toto - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > @@ -51,8 +52,6 @@ test docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BZL_PRE="--output_user_root=/bzl_cache/cache" - --build-arg BZL_OPTIONS="--verbose_failures" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME diff --git a/Dockerfile b/Dockerfile index ee05b7a0..4a4d3ab7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,8 +51,6 @@ ARG BZL_TARGETS="//tensorflow:libtensorflow_cc.so //tensorflow/tools/pip_package ARG BZL_CONFIGS="--config=nogcp --config=noaws --config=nohdfs --config=opt" # "--compilation_mode opt" is already enabled by default (see tf repo .bazelrc and configure.py) ARG BZL_OPTIONS="--verbose_failures --remote_cache=http://localhost:9090" -# options between "bazel" and "build" directive -ARG BZL_PRE="" # Build ARG ZIP_TF_BIN=false @@ -65,7 +63,7 @@ RUN git clone --single-branch -b $TF https://github.com/tensorflow/tensorflow.gi source ../build-env-tf.sh \ && ./configure \ && export TMP=/tmp/bazel \ - && BZL_CMD="$BZL_PRE build $BZL_TARGETS $BZL_CONFIGS $BZL_OPTIONS" \ + && BZL_CMD="build $BZL_TARGETS $BZL_CONFIGS $BZL_OPTIONS" \ && bazel $BZL_CMD --jobs="HOST_CPUS*$CPU_RATIO" ' \ # Installation - split here if you want to check files ^ #RUN cd tensorflow \ -- GitLab From f1bda5b1336a6a8a1d63204175e11a885b635665 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 16:29:51 +0200 Subject: [PATCH 113/154] REFAC: gdal_to_np_type is global constant --- python/otbtf.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index a1cf9bd4..a58d10d2 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -34,6 +34,19 @@ import tensorflow as tf from osgeo import gdal from tqdm import tqdm +# --------------------------------------------- GDAL to numpy types ---------------------------------------------------- + + +gdal_to_np_types = {1: 'uint8', + 2: 'uint16', + 3: 'int16', + 4: 'uint32', + 5: 'int32', + 6: 'float32', + 7: 'float64', + 10: 'complex64', + 11: 'complex128'} + # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -58,8 +71,6 @@ def read_as_np_arr(gdal_ds, as_patches=True): False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', - 10: 'complex64', 11: 'complex128'} gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) @@ -247,8 +258,6 @@ class PatchesImagesReader(PatchesReaderBase): @staticmethod def _read_extract_as_np_arr(gdal_ds, offset): - gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', - 10: 'complex64', 11: 'complex128'} assert gdal_ds is not None psz = gdal_ds.RasterXSize gdal_type = gdal_ds.GetRasterBand(1).DataType -- GitLab From ae46d5e19cc49714b941bd749f3d876a06ee0a0f Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 16:35:31 +0200 Subject: [PATCH 114/154] REFAC: pylint --- python/otbtf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index a58d10d2..2083250a 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -37,7 +37,7 @@ from tqdm import tqdm # --------------------------------------------- GDAL to numpy types ---------------------------------------------------- -gdal_to_np_types = {1: 'uint8', +GDAL_TO_NP_TYPES = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', @@ -73,7 +73,7 @@ def read_as_np_arr(gdal_ds, as_patches=True): """ gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize - buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) + buffer = gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type]) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: @@ -269,7 +269,7 @@ class PatchesImagesReader(PatchesReaderBase): else: # single-band raster buffer = np.expand_dims(buffer, axis=2) - return buffer.astype(gdal_to_np_types[gdal_type]) + return buffer.astype(GDAL_TO_NP_TYPES[gdal_type]) def get_sample(self, index): """ @@ -628,8 +628,8 @@ class TFRecords: """ data_converted = {} - for k, d in data.items(): - data_converted[k] = d.name + for key, value in data.items(): + data_converted[key] = value.name return data_converted @@ -644,7 +644,7 @@ class TFRecords: filepath = os.path.join(self.dirpath, f"{i}.records") with tf.io.TFRecordWriter(filepath) as writer: - for s in range(nb_sample): + for _ in range(nb_sample): sample = dataset.read_one_sample() serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in @@ -661,8 +661,8 @@ class TFRecords: :param filepath: Output file name """ - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) + with open(filepath, 'w') as file: + json.dump(data, file, indent=4) @staticmethod def load(filepath): -- GitLab From b1f064f0f11614ef8e40e9f168b50080d1ad2e14 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 16:39:27 +0200 Subject: [PATCH 115/154] REFAC: pylint --- python/otbtf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 2083250a..d7e1a0b0 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -670,8 +670,8 @@ class TFRecords: Return data from pickle format. :param filepath: Input file name """ - with open(filepath, 'r') as f: - return json.load(f) + with open(filepath, 'r') as file: + return json.load(file) def convert_dataset_output_shapes(self, dataset): """ @@ -680,8 +680,8 @@ class TFRecords: """ output_shapes = {} - for key in dataset.output_shapes.keys(): - output_shapes[key] = (None,) + dataset.output_shapes[key] + for key, value in dataset.output_shapes.keys(): + output_shapes[key] = (None,) + value self.save(output_shapes, self.output_shape_file) -- GitLab From 9277f4eddd83a350364f0e68168371304bd36f9d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 18:18:12 +0200 Subject: [PATCH 116/154] CI: change remote cache server --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 798ece96..a855105f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -52,6 +52,7 @@ test docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=http://172.17.0.1:9090" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From 846f370086884406ba37beb55796953bc0dd05d9 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 20:11:00 +0200 Subject: [PATCH 117/154] CI: remove dummy scripts --- .gitlab-ci.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a855105f..c4c1568c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,9 +31,6 @@ test docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - ls -ll / - - ls -ll /bzl_cache/ - - touch /bzl_cache/toto - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > -- GitLab From 061fcb65e4036d3210fccb0dbf9e4f6a3959dad4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 20:11:20 +0200 Subject: [PATCH 118/154] ENH: remove useless astype() conversion --- otbtf/dataset.py | 11 ++++------- otbtf/utils.py | 21 ++++----------------- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/otbtf/dataset.py b/otbtf/dataset.py index fc108531..2350ed80 100644 --- a/otbtf/dataset.py +++ b/otbtf/dataset.py @@ -27,7 +27,7 @@ import logging from abc import ABC, abstractmethod import numpy as np import tensorflow as tf -from otbtf.utils import read_as_np_arr, gdal_open, GDAL_TO_NP_TYPES +from otbtf.utils import read_as_np_arr, gdal_open from otbtf.tfrecords import TFRecords @@ -203,16 +203,13 @@ class PatchesImagesReader(PatchesReaderBase): def _read_extract_as_np_arr(gdal_ds, offset): assert gdal_ds is not None psz = gdal_ds.RasterXSize - gdal_type = gdal_ds.GetRasterBand(1).DataType yoff = int(offset * psz) assert yoff + psz <= gdal_ds.RasterYSize buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) if len(buffer.shape) == 3: - buffer = np.transpose(buffer, axes=(1, 2, 0)) - else: # single-band raster - buffer = np.expand_dims(buffer, axis=2) - - return buffer.astype(GDAL_TO_NP_TYPES[gdal_type]) + # multi-band raster + return np.transpose(buffer, axes=(1, 2, 0)) + return np.expand_dims(buffer, axis=2) def get_sample(self, index): """ diff --git a/otbtf/utils.py b/otbtf/utils.py index 28677ce7..f1e803d9 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -1,19 +1,6 @@ from osgeo import gdal import numpy as np -# --------------------------------------------- GDAL to numpy types ---------------------------------------------------- - - -GDAL_TO_NP_TYPES = {1: 'uint8', - 2: 'uint16', - 3: 'int16', - 4: 'uint32', - 5: 'int32', - 6: 'float32', - 7: 'float64', - 10: 'complex64', - 11: 'complex128'} - # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -41,10 +28,10 @@ def read_as_np_arr(gdal_ds, as_patches=True): size_x = gdal_ds.RasterXSize if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) - if not as_patches: - n_elems = 1 - size_y = gdal_ds.RasterYSize - else: + if as_patches: n_elems = int(gdal_ds.RasterYSize / size_x) size_y = size_x + else: + n_elems = 1 + size_y = gdal_ds.RasterYSize return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) -- GitLab From 92fc8d04745c4cb30511ed85594eae85624a326c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 20:14:10 +0200 Subject: [PATCH 119/154] CI: job name --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c4c1568c..89d8ed57 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -18,7 +18,7 @@ stages: - Applications Test - Ship -test docker image: +docker image: stage: Build allow_failure: false tags: [godzilla] -- GitLab From cd4d8fecef283e76b6ead1c6fb3620aad323a102 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 20:23:23 +0200 Subject: [PATCH 120/154] REFAC: TFRecords --- otbtf/tfrecords.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py index 17fa51e9..889fd898 100644 --- a/otbtf/tfrecords.py +++ b/otbtf/tfrecords.py @@ -53,20 +53,11 @@ class TFRecords: if not drop_remainder and dataset.size % n_samples_per_shard > 0: nb_shards += 1 - self.convert_dataset_output_shapes(dataset) - - def _convert_data(data): - """ - Convert data - """ - data_converted = {} - - for k, d in data.items(): - data_converted[k] = d.name - - return data_converted + output_shapes = {key: (None,) + output_shape for key, output_shape in dataset.output_shapes.items()} + self.save(output_shapes, self.output_shape_file) - self.save(_convert_data(dataset.output_types), self.output_types_file) + output_types = {key: output_type.name for key, output_type in dataset.output_types.items()} + self.save(output_types, self.output_types_file) for i in tqdm(range(nb_shards)): @@ -106,18 +97,6 @@ class TFRecords: with open(filepath, 'r') as f: return json.load(f) - def convert_dataset_output_shapes(self, dataset): - """ - Convert and save numpy shape to tensorflow shape. - :param dataset: Dataset object containing output shapes - """ - output_shapes = {} - - for key in dataset.output_shapes.keys(): - output_shapes[key] = (None,) + dataset.output_shapes[key] - - self.save(output_shapes, self.output_shape_file) - @staticmethod def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): """ -- GitLab From 4faa71a41120058d919f038761f0b1dd0f018343 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 21:53:56 +0200 Subject: [PATCH 121/154] REFAC: TFRecords --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 89d8ed57..d3daf916 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -49,7 +49,7 @@ docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=http://172.17.0.1:9090" + --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=172.17.0.1:9090" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From 6809cd62f1136327af4e55d07f6965f53647d673 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 21:55:25 +0200 Subject: [PATCH 122/154] CI: remove labels --- .gitlab-ci.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d3daf916..3cb141be 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -39,11 +39,6 @@ docker image: --network="host" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - --label "org.opencontainers.image.title=$CI_PROJECT_TITLE" - --label "org.opencontainers.image.url=$CI_PROJECT_URL" - --label "org.opencontainers.image.created=$CI_JOB_STARTED_AT" - --label "org.opencontainers.image.revision=$CI_COMMIT_SHA" - --label "org.opencontainers.image.version=$CI_COMMIT_REF_NAME" --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg OTBTESTS="true" --build-arg KEEP_SRC_OTB="true" -- GitLab From 6b4e114bd3d945b854b2a8abc1b0ba75177bbfcb Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:04:51 +0200 Subject: [PATCH 123/154] CI: patch dockerfile --- .gitlab-ci.yml | 2 -- Dockerfile | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3cb141be..6c738a64 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,7 +35,6 @@ docker image: - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > docker build - --pull --network="host" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME @@ -44,7 +43,6 @@ docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BZL_OPTIONS="--verbose_failures --remote_cache=172.17.0.1:9090" . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME diff --git a/Dockerfile b/Dockerfile index 223a14c6..e1c48f6f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -148,7 +148,7 @@ COPY --from=builder /src /src # System-wide ENV ENV PATH="/opt/otbtf/bin:$PATH" ENV LD_LIBRARY_PATH="/opt/otbtf/lib:$LD_LIBRARY_PATH" -ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf/python" +ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf/otbtf" ENV OTB_APPLICATION_PATH="/opt/otbtf/lib/otb/applications" # Default user, directory and command (bash is the entrypoint when using 'docker create') @@ -169,6 +169,6 @@ USER otbuser # User-only ENV # Test python imports -RUN python -c "import tensorflow" -RUN python -c "import otbtf, tricks" -RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')" +#RUN python -c "import tensorflow" +#RUN python -c "import otbtf, tricks" +#RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')" -- GitLab From a4a78ed154e56e67d30a7aed23b4118fe4563fa9 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:19:54 +0200 Subject: [PATCH 124/154] FIX: update PYTHONPATH --- Dockerfile | 8 ++++---- {otbtf => tricks}/ckpt2savedmodel.py | 0 {otbtf => tricks}/tricks.py | 0 3 files changed, 4 insertions(+), 4 deletions(-) rename {otbtf => tricks}/ckpt2savedmodel.py (100%) rename {otbtf => tricks}/tricks.py (100%) diff --git a/Dockerfile b/Dockerfile index e1c48f6f..d5a644f7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -148,7 +148,7 @@ COPY --from=builder /src /src # System-wide ENV ENV PATH="/opt/otbtf/bin:$PATH" ENV LD_LIBRARY_PATH="/opt/otbtf/lib:$LD_LIBRARY_PATH" -ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf/otbtf" +ENV PYTHONPATH="/opt/otbtf/lib/python3/site-packages:/opt/otbtf/lib/otb/python:/src/otbtf" ENV OTB_APPLICATION_PATH="/opt/otbtf/lib/otb/applications" # Default user, directory and command (bash is the entrypoint when using 'docker create') @@ -169,6 +169,6 @@ USER otbuser # User-only ENV # Test python imports -#RUN python -c "import tensorflow" -#RUN python -c "import otbtf, tricks" -#RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')" +RUN python -c "import tensorflow" +RUN python -c "import otbtf, tricks" +RUN python -c "import otbApplication as otb; otb.Registry.CreateApplication('ImageClassifierFromDeepFeatures')" diff --git a/otbtf/ckpt2savedmodel.py b/tricks/ckpt2savedmodel.py similarity index 100% rename from otbtf/ckpt2savedmodel.py rename to tricks/ckpt2savedmodel.py diff --git a/otbtf/tricks.py b/tricks/tricks.py similarity index 100% rename from otbtf/tricks.py rename to tricks/tricks.py -- GitLab From d81dc370e6957b2b430021dbbe3eb8989ac03313 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:20:13 +0200 Subject: [PATCH 125/154] ADD: update __init__ --- otbtf/__init__.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/otbtf/__init__.py b/otbtf/__init__.py index e69de29b..8a6951a6 100644 --- a/otbtf/__init__.py +++ b/otbtf/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# ========================================================================== +# +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2022 INRAE +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ==========================================================================*/ +""" +OTBTF python module +""" +from utils import read_as_np_arr, gdal_open +from dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ + DatasetFromPatchesImages +from tfrecords import TFRecords \ No newline at end of file -- GitLab From 07e436997330fc7bd8c52f312ae657a34bacca44 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:20:28 +0200 Subject: [PATCH 126/154] ADD: trick for deprecated stuff --- tricks/tricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tricks/tricks.py b/tricks/tricks.py index b31b14c3..d22e7e96 100644 --- a/tricks/tricks.py +++ b/tricks/tricks.py @@ -25,7 +25,7 @@ for TF 1.X versions. """ import tensorflow.compat.v1 as tf from deprecated import deprecated -from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds +from otbtf.utils import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds tf.disable_v2_behavior() -- GitLab From ca5c11caa9edc50847d124112fc07f6cbee9dd6b Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:20:53 +0200 Subject: [PATCH 127/154] ADD: .idea --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a29689f2..1ef65aa1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ # Compiled python source # *.pyc +.idea -- GitLab From 95658a6a85653ff0ec66eb10502e1d55f791e140 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:27:47 +0200 Subject: [PATCH 128/154] FIX: imports --- otbtf/__init__.py | 6 +++--- tricks/ckpt2savedmodel.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/otbtf/__init__.py b/otbtf/__init__.py index 8a6951a6..77f806b8 100644 --- a/otbtf/__init__.py +++ b/otbtf/__init__.py @@ -20,7 +20,7 @@ """ OTBTF python module """ -from utils import read_as_np_arr, gdal_open -from dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ +from otbtf.utils import read_as_np_arr, gdal_open +from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ DatasetFromPatchesImages -from tfrecords import TFRecords \ No newline at end of file +from otbtf.tfrecords import TFRecords \ No newline at end of file diff --git a/tricks/ckpt2savedmodel.py b/tricks/ckpt2savedmodel.py index 117203ba..ff22965f 100755 --- a/tricks/ckpt2savedmodel.py +++ b/tricks/ckpt2savedmodel.py @@ -26,7 +26,7 @@ can be more conveniently exported as SavedModel (see how to build a model with keras in Tensorflow 2). """ import argparse -from tricks import ckpt_to_savedmodel +from tricks.tricks import ckpt_to_savedmodel def main(): -- GitLab From 1a22e0dc14ced66a0fb1ebc04bb3f107de77dc62 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 22:30:19 +0200 Subject: [PATCH 129/154] FIX: imports --- tricks/{tricks.py => __init__.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tricks/{tricks.py => __init__.py} (100%) diff --git a/tricks/tricks.py b/tricks/__init__.py similarity index 100% rename from tricks/tricks.py rename to tricks/__init__.py -- GitLab From ebb3a94094c6e32df7a8ffb764237c30f9ae1ca2 Mon Sep 17 00:00:00 2001 From: Cresson Remi <remi.cresson@irstea.fr> Date: Sat, 23 Apr 2022 03:31:21 +0200 Subject: [PATCH 130/154] CI: Update .gitlab-ci.yml --- .gitlab-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6c738a64..2c5eb0c4 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -54,13 +54,13 @@ flake8: extends: .static_analysis_base script: - sudo apt update && sudo apt install flake8 -y - - python -m flake8 --max-line-length=120 $OTBTF_SRC/python + - python -m flake8 --max-line-length=120 $OTBTF_SRC/otbtf pylint: extends: .static_analysis_base script: - sudo apt update && sudo apt install pylint -y - - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/python + - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/otbtf codespell: extends: .static_analysis_base -- GitLab From e3b92f4b767aee5fb3aa5befd932cccb0704f0ce Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sat, 23 Apr 2022 09:50:02 +0200 Subject: [PATCH 131/154] CI: change tgt branch --- .gitlab-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2c5eb0c4..b640de0d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -24,7 +24,7 @@ docker image: tags: [godzilla] image: docker/compose:latest except: - - master + - develop services: - name: docker:dind before_script: @@ -124,7 +124,7 @@ sr4rs: deploy: stage: Ship only: - - master + - develop script: - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 6b89dc17845abb215ba75a130b1236ceafa14732 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sat, 23 Apr 2022 09:50:18 +0200 Subject: [PATCH 132/154] STYLE: fix pylint/flake8 --- otbtf/__init__.py | 6 +++--- otbtf/tfrecords.py | 26 ++++++++++++++++++++++++-- otbtf/utils.py | 22 ++++++++++++++++++++++ 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/otbtf/__init__.py b/otbtf/__init__.py index 77f806b8..46cbd935 100644 --- a/otbtf/__init__.py +++ b/otbtf/__init__.py @@ -20,7 +20,7 @@ """ OTBTF python module """ -from otbtf.utils import read_as_np_arr, gdal_open +from otbtf.utils import read_as_np_arr, gdal_open # noqa: 401 from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ - DatasetFromPatchesImages -from otbtf.tfrecords import TFRecords \ No newline at end of file + DatasetFromPatchesImages # noqa: 401 +from otbtf.tfrecords import TFRecords # noqa: 401 diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py index 889fd898..4aca2880 100644 --- a/otbtf/tfrecords.py +++ b/otbtf/tfrecords.py @@ -1,3 +1,25 @@ +# -*- coding: utf-8 -*- +# ========================================================================== +# +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2022 INRAE +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ==========================================================================*/ +""" +The tfrecords module provides an implementation for the TFRecords files read/write +""" import glob import json import os @@ -68,7 +90,7 @@ class TFRecords: filepath = os.path.join(self.dirpath, f"{i}.records") with tf.io.TFRecordWriter(filepath) as writer: - for s in range(nb_sample): + for _ in range(nb_sample): sample = dataset.read_one_sample() serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in @@ -184,4 +206,4 @@ class TFRecords: dataset = dataset.batch(1) sample = iter(dataset).get_next() - return sample \ No newline at end of file + return sample diff --git a/otbtf/utils.py b/otbtf/utils.py index f1e803d9..5694772b 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -1,3 +1,25 @@ +# -*- coding: utf-8 -*- +# ========================================================================== +# +# Copyright 2018-2019 IRSTEA +# Copyright 2020-2022 INRAE +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ==========================================================================*/ +""" +The utils module provides some helpers to read patches using gdal +""" from osgeo import gdal import numpy as np -- GitLab From 986f8af898bab64ed3f372557c675273982d2947 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sat, 23 Apr 2022 09:58:08 +0200 Subject: [PATCH 133/154] CI: change network --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b640de0d..6db57c4c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,7 +35,7 @@ docker image: - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > docker build - --network="host" + --network="gitlab-runner-net" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From d6d354e64cba6f956217ce98c3bc9266b07e17e8 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sat, 23 Apr 2022 13:22:36 +0200 Subject: [PATCH 134/154] CI: change network --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6db57c4c..b640de0d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -35,7 +35,7 @@ docker image: - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > docker build - --network="gitlab-runner-net" + --network="host" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From 946c1414576751d624ff3d00834001f75f321a72 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sun, 24 Apr 2022 14:06:51 +0200 Subject: [PATCH 135/154] CI: use buildkit to optimize cache for multistage build --- .gitlab-ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b640de0d..c84d5090 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -34,7 +34,7 @@ docker image: - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > - docker build + DOCKER_BUILDKIT=1 docker build --network="host" --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME @@ -43,6 +43,7 @@ docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BUILDKIT_INLINE_CACHE=1 . - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From 9af2be64335c2098db52a60bde31a325a8040bd9 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sun, 24 Apr 2022 19:52:50 +0200 Subject: [PATCH 136/154] STYLE: pylint, flake8 fixes --- .gitlab-ci.yml | 4 ++-- otbtf/__init__.py | 6 +++--- otbtf/tfrecords.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c84d5090..87964c83 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -55,13 +55,13 @@ flake8: extends: .static_analysis_base script: - sudo apt update && sudo apt install flake8 -y - - python -m flake8 --max-line-length=120 $OTBTF_SRC/otbtf + - python -m flake8 --max-line-length=120 --per-file-ignores="__init__.py:F401" $OTBTF_SRC/otbtf pylint: extends: .static_analysis_base script: - sudo apt update && sudo apt install pylint -y - - pylint --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/otbtf + - pylint --logging-format-style=old --disable=too-many-nested-blocks,too-many-locals,too-many-statements,too-few-public-methods,too-many-instance-attributes,too-many-arguments --ignored-modules=tensorflow --max-line-length=120 --logging-format-style=new $OTBTF_SRC/otbtf codespell: extends: .static_analysis_base diff --git a/otbtf/__init__.py b/otbtf/__init__.py index 46cbd935..ac36018a 100644 --- a/otbtf/__init__.py +++ b/otbtf/__init__.py @@ -20,7 +20,7 @@ """ OTBTF python module """ -from otbtf.utils import read_as_np_arr, gdal_open # noqa: 401 +from otbtf.utils import read_as_np_arr, gdal_open from otbtf.dataset import Buffer, PatchesReaderBase, PatchesImagesReader, IteratorBase, RandomIterator, Dataset, \ - DatasetFromPatchesImages # noqa: 401 -from otbtf.tfrecords import TFRecords # noqa: 401 + DatasetFromPatchesImages +from otbtf.tfrecords import TFRecords diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py index 4aca2880..323e80cf 100644 --- a/otbtf/tfrecords.py +++ b/otbtf/tfrecords.py @@ -107,8 +107,8 @@ class TFRecords: :param filepath: Output file name """ - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) + with open(filepath, 'w') as file: + json.dump(data, file, indent=4) @staticmethod def load(filepath): @@ -116,8 +116,8 @@ class TFRecords: Return data from pickle format. :param filepath: Input file name """ - with open(filepath, 'r') as f: - return json.load(f) + with open(filepath, 'r') as file: + return json.load(file) @staticmethod def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): -- GitLab From 9696ddd9745c55a100f011cc6939752d730317b5 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sun, 24 Apr 2022 21:39:37 +0200 Subject: [PATCH 137/154] ADD: simplify TFRecords --- otbtf/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/otbtf/dataset.py b/otbtf/dataset.py index 2350ed80..00275481 100644 --- a/otbtf/dataset.py +++ b/otbtf/dataset.py @@ -171,9 +171,9 @@ class PatchesImagesReader(PatchesReaderBase): else: if self.nb_of_channels[src_key] != gdal_ds.RasterCount: raise Exception("All patches images from one source must have the same number of channels!" - "Error happened for source: {}".format(src_key)) + f"Error happened for source: {src_key}") if len(set(nb_of_patches.values())) != 1: - raise Exception("Sources must have the same number of patches! Number of patches: {}".format(nb_of_patches)) + raise Exception(f"Sources must have the same number of patches! Number of patches: {nb_of_patches}") # gdal_ds sizes src_key_0 = list(self.gdal_ds)[0] # first key -- GitLab From cf97f216cdb7a491b9f7d8d591fd308a963c7b7a Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sun, 24 Apr 2022 21:39:49 +0200 Subject: [PATCH 138/154] ADD: simplify TFRecords --- otbtf/tfrecords.py | 41 ++++++++++------------------------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py index 323e80cf..b2aae0b2 100644 --- a/otbtf/tfrecords.py +++ b/otbtf/tfrecords.py @@ -36,15 +36,10 @@ class TFRecords: def __init__(self, path): """ - :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path + :param path: Can be a directory where TFRecords must be saved/loaded """ - if os.path.isdir(path) or not os.path.exists(path): - self.dirpath = path - os.makedirs(self.dirpath, exist_ok=True) - self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") - else: - self.dirpath = os.path.dirname(path) - self.tfrecords_pattern_path = path + self.dirpath = path + os.makedirs(self.dirpath, exist_ok=True) self.output_types_file = os.path.join(self.dirpath, "output_types.json") self.output_shape_file = os.path.join(self.dirpath, "output_shape.json") self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None @@ -102,7 +97,7 @@ class TFRecords: @staticmethod def save(data, filepath): """ - Save data to pickle format. + Save data to JSON format. :param data: Data to save json format :param filepath: Output file name """ @@ -113,7 +108,7 @@ class TFRecords: @staticmethod def load(filepath): """ - Return data from pickle format. + Return data from JSON format. :param filepath: Input file name """ with open(filepath, 'r') as file: @@ -172,15 +167,15 @@ class TFRecords: # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ? - # 3/ shuffle or not shuffle ? - matching_files = glob.glob(self.tfrecords_pattern_path) - logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path) + tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") + matching_files = glob.glob(tfrecords_pattern_path) + logging.info('Searching TFRecords in %s...', tfrecords_pattern_path) logging.info('Number of matching TFRecords: %s', len(matching_files)) matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)] # files multiple of workers nb_matching_files = len(matching_files) if nb_matching_files == 0: - raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord " - "files is greater or equal than the number of workers!".format(self.tfrecords_pattern_path)) + raise Exception(f"At least one worker has no TFRecord file in {tfrecords_pattern_path}. Please ensure that " + "the number of TFRecord files is greater or equal than the number of workers!") logging.info('Reducing number of records to : %s', nb_matching_files) dataset = tf.data.TFRecordDataset(matching_files) # , num_parallel_reads=2) # interleaves reads from xxx files dataset = dataset.with_options(options) # uses data as soon as it streams in, rather than in its original order @@ -189,21 +184,5 @@ class TFRecords: dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) - # TODO voir si on met le prefetch avant le batch cf https://keras.io/examples/keras_recipes/tfrecord/ return dataset - - def read_one_sample(self, target_keys): - """ - Read one tfrecord file matching with pattern and convert data to tensorflow dataset. - :param target_key: Key of the target, e.g. 's2_out' - """ - matching_files = glob.glob(self.tfrecords_pattern_path) - one_file = matching_files[0] - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) - dataset = tf.data.TFRecordDataset(one_file) - dataset = dataset.map(parse) - dataset = dataset.batch(1) - - sample = iter(dataset).get_next() - return sample -- GitLab From 2076b5b6c7bf7ca35eff187224df2fa09a18b862 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Sun, 24 Apr 2022 21:40:00 +0200 Subject: [PATCH 139/154] REFAC: use f string --- otbtf/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/otbtf/utils.py b/otbtf/utils.py index 5694772b..729f7715 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -34,7 +34,7 @@ def gdal_open(filename): """ gdal_ds = gdal.Open(filename) if not gdal_ds: - raise Exception("Unable to open file {}".format(filename)) + raise Exception(f"Unable to open file {filename}") return gdal_ds -- GitLab From 31ac638d4ede12ade89fd2cd3b71c118d054b1e4 Mon Sep 17 00:00:00 2001 From: Vincent Delbar <vincent.delbar@latelescop.fr> Date: Mon, 25 Apr 2022 18:39:31 +0200 Subject: [PATCH 140/154] FIX: tfrecords dtype is forced to float32 --- otbtf/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/otbtf/utils.py b/otbtf/utils.py index 729f7715..069638a5 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -38,12 +38,13 @@ def gdal_open(filename): return gdal_ds -def read_as_np_arr(gdal_ds, as_patches=True): +def read_as_np_arr(gdal_ds, as_patches=True, dtype=None): """ Read a GDAL raster as numpy array :param gdal_ds: a GDAL dataset instance :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If False, the shape is (1, psz_y, psz_x, nb_channels) + :param dtype: if not None array dtype will be cast to given numpy data type (np.float32, np.uint16...) :return: Numpy array of dim 4 """ buffer = gdal_ds.ReadAsArray() @@ -56,4 +57,9 @@ def read_as_np_arr(gdal_ds, as_patches=True): else: n_elems = 1 size_y = gdal_ds.RasterYSize - return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) + + buffer = buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)) + if dtype is not None: + buffer = buffer.astype(dtype) + + return buffer -- GitLab From 60851101f4d2f80cad31df9bc160e81572d70f0d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 11:21:06 +0200 Subject: [PATCH 141/154] CI: test SR4RS fix --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 87964c83..4f302884 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -118,7 +118,7 @@ sr4rs: - wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/kDms9JrRMQE2Q5z/download - unzip -o sr4rs_data.zip - rm -rf sr4rs - - git clone https://github.com/remicres/sr4rs.git + - git clone -b 44-cast_float_input https://github.com/remicres/sr4rs.git - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py -- GitLab From bfe6a3ebb251975ba97da6fa8d0807b086154ce2 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 15:05:00 +0200 Subject: [PATCH 142/154] WIP: fixed SR4RS repo --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 4f302884..87964c83 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -118,7 +118,7 @@ sr4rs: - wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/kDms9JrRMQE2Q5z/download - unzip -o sr4rs_data.zip - rm -rf sr4rs - - git clone -b 44-cast_float_input https://github.com/remicres/sr4rs.git + - git clone https://github.com/remicres/sr4rs.git - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py -- GitLab From dcc19a3574b33550705642b9dad746fad81c84f1 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 19:38:56 +0200 Subject: [PATCH 143/154] DOC: update docker images section --- .gitlab-ci.yml | 48 +++++++++++++++++++++++++++++++++++++++++++----- doc/DOCKERUSE.md | 5 +++-- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 87964c83..78a1fe4e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,11 +31,43 @@ docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > - DOCKER_BUILDKIT=1 docker build + docker build + --target otbtf-base --network="host" + --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + --tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + --build-arg OTBTESTS="true" + --build-arg KEEP_SRC_OTB="true" + --build-arg BZL_CONFIGS="" + --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BUILDKIT_INLINE_CACHE=1 + . + - docker push $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + - > + docker build + --target builder + --network="host" + --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + --cache-from $CI_REGISTRY_IMAGE:builder-cpu-basic-test + --cache-from $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + --tag $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + --build-arg OTBTESTS="true" + --build-arg KEEP_SRC_OTB="true" + --build-arg BZL_CONFIGS="" + --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BUILDKIT_INLINE_CACHE=1 + . + - docker push $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + - > + docker build + --network="host" + --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + --cache-from $CI_REGISTRY_IMAGE:builder-cpu-basic-test + --cache-from $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME @@ -43,9 +75,9 @@ docker image: --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BUILDKIT_INLINE_CACHE=1 + --build-arg BUILDKIT_INLINE_CACHE=1 . - - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + - docker push $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME .static_analysis_base: stage: Static Analysis @@ -127,6 +159,12 @@ deploy: only: - develop script: + - docker pull $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + - docker pull $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + - docker tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + - docker tag $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:builder-cpu-basic-test - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:builder-cpu-basic-test - docker push $CI_REGISTRY_IMAGE:cpu-basic-test diff --git a/doc/DOCKERUSE.md b/doc/DOCKERUSE.md index 69382c96..9c100896 100644 --- a/doc/DOCKERUSE.md +++ b/doc/DOCKERUSE.md @@ -26,8 +26,9 @@ Here is the list of OTBTF docker images hosted on [dockerhub](https://hub.docker | **mdl4eo/otbtf3.0:gpu-dev** | Ubuntu Focal | r2.5 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| | **mdl4eo/otbtf3.1:cpu-basic** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| | **mdl4eo/otbtf3.1:cpu-basic-dev** | Ubuntu Focal | r2.8 | 7.4.0 | CPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:gpu-basic** | Ubuntu Focal | r2.8 | 7.4.0 | GPU | yes | 5.2,6.1,7.0,7.5,8.6| -| **mdl4eo/otbtf3.1:gpu** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.1:gpu-basic** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization | no | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.1:gpu-basic-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU, no optimization (dev) | yes | 5.2,6.1,7.0,7.5,8.6| +| **mdl4eo/otbtf3.1:gpu** | Ubuntu Focal | r2.8 | 7.4.0 | GPU | no | 5.2,6.1,7.0,7.5,8.6| | **mdl4eo/otbtf3.1:gpu-dev** | Ubuntu Focal | r2.8 | 7.4.0 | GPU (dev) | yes | 5.2,6.1,7.0,7.5,8.6| - `cpu` tagged docker images are compiled for CPU usage only. -- GitLab From 50190566b07d8862f9eafd743fc8c227e7214661 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 19:48:18 +0200 Subject: [PATCH 144/154] ADD: docker pull --- .gitlab-ci.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 78a1fe4e..af1f6ffd 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,6 +31,12 @@ docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: + - docker pull $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME || + - docker pull $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME || + - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || + - docker pull $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test || + - docker pull $CI_REGISTRY_IMAGE:builder-cpu-basic-test || + - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > docker build --target otbtf-base -- GitLab From 62b4485b2518a3c9e418b6d79b7fd8a0b78d36f4 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 19:56:14 +0200 Subject: [PATCH 145/154] ADD: docker pull --- .gitlab-ci.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index af1f6ffd..698795f2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -44,11 +44,7 @@ docker image: --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME --tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - --build-arg OTBTESTS="true" - --build-arg KEEP_SRC_OTB="true" - --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" - --build-arg BUILDKIT_INLINE_CACHE=1 . - docker push $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - > @@ -84,6 +80,13 @@ docker image: --build-arg BUILDKIT_INLINE_CACHE=1 . - docker push $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + after_script: + - docker tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + - docker tag $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:builder-cpu-basic-test + - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:builder-cpu-basic-test + - docker push $CI_REGISTRY_IMAGE:cpu-basic-test .static_analysis_base: stage: Static Analysis @@ -165,6 +168,7 @@ deploy: only: - develop script: + - echo "Shippping!" - docker pull $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - docker pull $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From e2941df8e787c4a7e45cb2b505b14d973444d7c6 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 21:03:24 +0200 Subject: [PATCH 146/154] ADD: docker pull --- .gitlab-ci.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 698795f2..ff4d0f75 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -6,7 +6,9 @@ variables: OTB_TEST_DIR: $OTB_BUILD/Testing/Temporary # OTB testing directory ARTIFACT_TEST_DIR: $CI_PROJECT_DIR/testing CRC_BOOK_TMP: /tmp/crc_book_tests_tmp - + DOCKER_BUILDKIT: 1 + DOCKER_DRIVER: overlay2 + workflow: rules: - if: $CI_MERGE_REQUEST_ID # Execute jobs in merge request context @@ -45,6 +47,7 @@ docker image: --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME --tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME --build-arg BASE_IMG="ubuntu:20.04" + --build-arg BUILDKIT_INLINE_CACHE=1 . - docker push $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - > -- GitLab From 4888e055357a4f7a5a3c48bccdb2eea60fb6627d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 21:30:29 +0200 Subject: [PATCH 147/154] ADD: docker pull --- .gitlab-ci.yml | 52 +++++++++++++++----------------------------------- 1 file changed, 15 insertions(+), 37 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ff4d0f75..41881fb1 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -8,6 +8,8 @@ variables: CRC_BOOK_TMP: /tmp/crc_book_tests_tmp DOCKER_BUILDKIT: 1 DOCKER_DRIVER: overlay2 + CACHE_IMAGE_BASE: $CI_REGISTRY_IMAGE:otbtf-base + CACHE_IMAGE_BUILDER: $CI_REGISTRY_IMAGE:builder workflow: rules: @@ -33,47 +35,33 @@ docker image: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours script: - - docker pull $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME || - - docker pull $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME || - - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME || - - docker pull $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test || - - docker pull $CI_REGISTRY_IMAGE:builder-cpu-basic-test || - - docker pull $CI_REGISTRY_IMAGE:cpu-basic-test || - > docker build --target otbtf-base --network="host" - --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - --tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + --cache-from $CACHE_IMAGE_BASE + --tag $CACHE_IMAGE_BASE --build-arg BASE_IMG="ubuntu:20.04" --build-arg BUILDKIT_INLINE_CACHE=1 - . - - docker push $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME + "." - > docker build --target builder --network="host" - --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - --cache-from $CI_REGISTRY_IMAGE:builder-cpu-basic-test - --cache-from $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME - --tag $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + --cache-from $CACHE_IMAGE_BASE + --cache-from $CACHE_IMAGE_BUILDER + --tag $CACHE_IMAGE_BUILDER --build-arg OTBTESTS="true" --build-arg KEEP_SRC_OTB="true" --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" --build-arg BUILDKIT_INLINE_CACHE=1 - . - - docker push $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + "." - > docker build --network="host" - --cache-from $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - --cache-from $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - --cache-from $CI_REGISTRY_IMAGE:builder-cpu-basic-test - --cache-from $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME - --cache-from $CI_REGISTRY_IMAGE:cpu-basic-test + --cache-from $CACHE_IMAGE_BASE + --cache-from $CACHE_IMAGE_BUILDER --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg OTBTESTS="true" @@ -81,15 +69,11 @@ docker image: --build-arg BZL_CONFIGS="" --build-arg BASE_IMG="ubuntu:20.04" --build-arg BUILDKIT_INLINE_CACHE=1 - . - - docker push $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME + "." after_script: - - docker tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - - docker tag $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:builder-cpu-basic-test - - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:builder-cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:cpu-basic-test + - docker push $CACHE_IMAGE_BASE + - docker push $CACHE_IMAGE_BUILDER + - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME .static_analysis_base: stage: Static Analysis @@ -172,12 +156,6 @@ deploy: - develop script: - echo "Shippping!" - - docker pull $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME - - docker pull $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - - docker tag $CI_REGISTRY_IMAGE:otbtf-base-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - - docker tag $CI_REGISTRY_IMAGE:builder-$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:builder-cpu-basic-test - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:otbtf-base-cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:builder-cpu-basic-test - docker push $CI_REGISTRY_IMAGE:cpu-basic-test -- GitLab From 2e7ab53ef002563e6c6c6ec21ff1498648f45604 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 22:59:10 +0200 Subject: [PATCH 148/154] ADD: ports --- .gitlab-ci.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 41881fb1..5dc5a7c7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,6 +31,8 @@ docker image: - develop services: - name: docker:dind + ports: + - "9090:9090" before_script: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours @@ -120,10 +122,6 @@ ctest: .applications_test_base: extends: .tests_base stage: Applications Test - rules: - # Only for MR targeting 'develop' and 'master' branches because applications tests are slow - - if: $CI_MERGE_REQUEST_ID && $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == 'develop' - - if: $CI_MERGE_REQUEST_ID && $CI_MERGE_REQUEST_TARGET_BRANCH_NAME == 'master' before_script: - pip3 install pytest pytest-cov pytest-order - mkdir -p $ARTIFACT_TEST_DIR -- GitLab From 94f018008fdcc1a955719782fd5ef385da97c309 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 23:00:15 +0200 Subject: [PATCH 149/154] ADD: ports --- .gitlab-ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 5dc5a7c7..e9c2ea92 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -31,8 +31,6 @@ docker image: - develop services: - name: docker:dind - ports: - - "9090:9090" before_script: - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY timeout: 10 hours -- GitLab From 7c0c47b7d58ddc2e0106537c080311a1a547a627 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Tue, 10 May 2022 10:36:47 +0200 Subject: [PATCH 150/154] ADD: build all docker images --- .gitlab-ci.yml | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e9c2ea92..d1043491 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -63,6 +63,7 @@ docker image: --cache-from $CACHE_IMAGE_BASE --cache-from $CACHE_IMAGE_BUILDER --cache-from $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME + --cache-from $CI_REGISTRY_IMAGE:cpu-basic-dev-testing --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME --build-arg OTBTESTS="true" --build-arg KEEP_SRC_OTB="true" @@ -153,5 +154,18 @@ deploy: script: - echo "Shippping!" - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME - - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-test - - docker push $CI_REGISTRY_IMAGE:cpu-basic-test + - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:cpu-basic-dev-testing + - docker push $CI_REGISTRY_IMAGE:cpu-basic-dev-testing + - docker build --network='host' --tag $CI_REGISTRY_IMAGE:cpu-basic --build-arg BASE_IMG=ubuntu:20.04 --build-arg BZL_CONFIGS="" . # cpu-basic + - docker push $CI_REGISTRY_IMAGE:cpu-basic + - docker build --network='host' --tag $CI_REGISTRY_IMAGE:cpu-basic-dev --build-arg BASE_IMG=ubuntu:20.04 --build-arg BZL_CONFIGS="" --build-arg KEEP_SRC_OTB=true . # cpu-basic-dev + - docker push $CI_REGISTRY_IMAGE:cpu-basic-dev + - docker build --network='host' --tag $CI_REGISTRY_IMAGE:gpu --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 . # gpu + - docker push $CI_REGISTRY_IMAGE:gpu + - docker build --network='host' --tag mdl4eo/otbtf${VER}:gpu-dev --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg KEEP_SRC_OTB=true . # gpu-dev + - docker push $CI_REGISTRY_IMAGE:gpu-dev + - docker build --network='host' --tag $CI_REGISTRY_IMAGE:gpu-basic --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg BZL_CONFIGS="" . # gpu-basic + - docker push $CI_REGISTRY_IMAGE:gpu-basic + - docker build --network='host' --tag mdl4eo/otbtf${VER}:gpu-basic-dev --build-arg BZL_CONFIGS="" --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg KEEP_SRC_OTB=true . # gpu-basic-dev + - docker push $CI_REGISTRY_IMAGE:gpu-basic-dev + -- GitLab From 0f4b3c74820184c14668f42ce919118ef888822e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Tue, 10 May 2022 12:21:22 +0200 Subject: [PATCH 151/154] ADD: rule for build all --- .gitlab-ci.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d1043491..19975de5 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -147,10 +147,11 @@ sr4rs: - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py -deploy: +buildall: stage: Ship - only: - - develop + rules: + - if: $CI_COMMIT_BRANCH == 'master' && $CI_PIPELINE_SOURCE == 'merge_request_event' + - if: $CI_COMMIT_BRANCH == 'develop' && $CI_PIPELINE_SOURCE == 'merge_request_event' script: - echo "Shippping!" - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From 17fdf6bf22d96d967b8989227c4efe2db86a881a Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Tue, 10 May 2022 15:04:42 +0200 Subject: [PATCH 152/154] ADD: deploy when commit to develop --- .gitlab-ci.yml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 19975de5..49d62396 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -13,8 +13,8 @@ variables: workflow: rules: - - if: $CI_MERGE_REQUEST_ID # Execute jobs in merge request context - + - if: $CI_MERGE_REQUEST_ID || $CI_COMMIT_REF_NAME =~ /develop/ # Execute jobs in merge request context, or commit in develop branch + stages: - Build - Static Analysis @@ -147,11 +147,10 @@ sr4rs: - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py -buildall: +deploy: stage: Ship - rules: - - if: $CI_COMMIT_BRANCH == 'master' && $CI_PIPELINE_SOURCE == 'merge_request_event' - - if: $CI_COMMIT_BRANCH == 'develop' && $CI_PIPELINE_SOURCE == 'merge_request_event' + only: + - develop script: - echo "Shippping!" - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From f38d87ec16a6cd0f04db5f7c58f30f787866cc1d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Tue, 10 May 2022 16:53:07 +0200 Subject: [PATCH 153/154] ADD: deploy when commit to develop --- .gitlab-ci.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 49d62396..55004d00 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -151,6 +151,13 @@ deploy: stage: Ship only: - develop + tags: [godzilla] + image: docker/compose:latest + services: + - name: docker:dind + before_script: + - echo -n $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY + timeout: 10 hours script: - echo "Shippping!" - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME -- GitLab From e5c9c85911a97f3184bdbcf8a003d60b7c7251f3 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Tue, 10 May 2022 21:38:55 +0200 Subject: [PATCH 154/154] ADD: deploy when master is updated --- .gitlab-ci.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 55004d00..ee9a0e80 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -13,7 +13,7 @@ variables: workflow: rules: - - if: $CI_MERGE_REQUEST_ID || $CI_COMMIT_REF_NAME =~ /develop/ # Execute jobs in merge request context, or commit in develop branch + - if: $CI_MERGE_REQUEST_ID || $CI_COMMIT_REF_NAME =~ /master/ # Execute jobs in merge request context, or commit in master branch stages: - Build @@ -29,6 +29,7 @@ docker image: image: docker/compose:latest except: - develop + - master services: - name: docker:dind before_script: @@ -150,7 +151,7 @@ sr4rs: deploy: stage: Ship only: - - develop + - master tags: [godzilla] image: docker/compose:latest services: @@ -169,10 +170,10 @@ deploy: - docker push $CI_REGISTRY_IMAGE:cpu-basic-dev - docker build --network='host' --tag $CI_REGISTRY_IMAGE:gpu --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 . # gpu - docker push $CI_REGISTRY_IMAGE:gpu - - docker build --network='host' --tag mdl4eo/otbtf${VER}:gpu-dev --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg KEEP_SRC_OTB=true . # gpu-dev + - docker build --network='host' --tag $CI_REGISTRY_IMAGE:gpu-dev --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg KEEP_SRC_OTB=true . # gpu-dev - docker push $CI_REGISTRY_IMAGE:gpu-dev - docker build --network='host' --tag $CI_REGISTRY_IMAGE:gpu-basic --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg BZL_CONFIGS="" . # gpu-basic - docker push $CI_REGISTRY_IMAGE:gpu-basic - - docker build --network='host' --tag mdl4eo/otbtf${VER}:gpu-basic-dev --build-arg BZL_CONFIGS="" --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg KEEP_SRC_OTB=true . # gpu-basic-dev + - docker build --network='host' --tag $CI_REGISTRY_IMAGE:gpu-basic-dev --build-arg BZL_CONFIGS="" --build-arg BASE_IMG=nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 --build-arg KEEP_SRC_OTB=true . # gpu-basic-dev - docker push $CI_REGISTRY_IMAGE:gpu-basic-dev -- GitLab