The methods mentioned above do not expose the backpropagation / optimization process at all; it is handled internally by tff.templates.IterativeProcess. Each time we pass in the current state and a training set, we get back a new state and metrics. To customize the optimization, we need to write our own tff.templates.IterativeProcess, implement its initialize and next methods, and plug in our own optimization logic.
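For reference, a built-in iterative process is typically driven like this (a hedged sketch; the exact return shape of next varies across TFF versions, some of which return a state/metrics pair and some a structured output):

```python
state = iterative_process.initialize()
for round_num in range(10):
  # next() consumes the current state plus the per-client datasets and
  # returns the updated state together with training metrics.
  state, metrics = iterative_process.next(state, federated_train_data)
  print(round_num, metrics)
```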
Data types
Federated Core provides the following types:
- Tensor type (tff.TensorType). Tensors here are not limited to Python tf.Tensor objects that represent outputs of TensorFlow ops in a TensorFlow graph; they may also represent data units that are produced, for example, as the output of a distributed aggregation protocol. The compact representation of a tensor type is dtype or dtype[shape]. For example, int32 and int32[10] are the types of an integer and of an integer vector, respectively.
- Sequence type (tff.SequenceType). This is TFF's abstract equivalent of TensorFlow's tf.data.Dataset. Elements of a sequence are consumed in order, and may themselves have complex types. The compact representation of a sequence type is T*, where T is the element type. For example, int32* represents a sequence of integers.
- Named tuple type (tff.StructType). This is how TFF constructs tuple- or dictionary-like structures with a predefined number of elements of specified types, named or not. Importantly, TFF's concept of a named tuple covers the abstract equivalent of Python's argument tuples, i.e. collections in which some (but not necessarily all) elements are named and the rest are positional. The compact representation is <n_1=T_1, ..., n_k=T_k>, where each n_k is an optional element name and T_k is the element type. For example, <int32, int32> is a pair of unnamed integers, and <x=float32, y=float32> is a pair of floats named x and y (which might represent a point on a plane). Tuples can be nested and mixed with other types; for example, <x=float32, y=float32>* is a sequence of points.
- Function type (tff.FunctionType). TFF is a functional programming framework in which functions are first-class values. The compact representation of a function type is (T -> U), where T is the parameter type and U is the result type; a parameterless function (a concept that mostly exists only at the Python level) is written ( -> U). For example, (int32* -> int32) is the type of a function that reduces a sequence of integers to a single integer value. A function has at most one parameter and exactly one result. (A short sketch of these compact representations follows this list.)
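These compact representations can be seen by printing the type objects directly. A minimal sketch, noting that the constructor signatures and output strings can vary slightly across TFF versions:

```python
import tensorflow as tf
import tensorflow_federated as tff

print(tff.TensorType(tf.int32, [10]))                          # 'int32[10]'
print(tff.SequenceType(tf.int32))                              # 'int32*'
print(tff.StructType([('x', tf.float32), ('y', tf.float32)]))  # '<x=float32,y=float32>'
print(tff.FunctionType(tff.SequenceType(tf.int32), tf.int32))  # '(int32* -> int32)'
```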
The following types address the distributed-systems aspects of TFF computations:
- Placement type. Apart from the two literals tff.SERVER and tff.CLIENTS (which can be thought of as constants of this type), this type is not yet exposed in the public API; it is for internal use, but will be introduced in a future version of the public API. Its compact representation is placement. A placement represents a collective of system participants playing a particular role. The original design targets client-server computations, with two groups of participants: clients and a server (the latter can be thought of as a singleton group). In more complex architectures, however, there may be other roles, such as intermediate aggregators in a multi-tier system, which might perform different kinds of aggregation, or use different kinds of data compression / decompression, than either the server or the clients. The main purpose of the placement concept is to serve as a basis for defining federated types.
- Federated type (tff.FederatedType). A value of a federated type is a set of values hosted by the group of system participants defined by a specific placement (such as tff.SERVER or tff.CLIENTS). A federated type is defined by its placement value (it is therefore a dependent type), the type of its member constituents (the content each participant hosts locally), and an all_equal bit specifying whether all participants locally host the same item. For a federated type containing items of type T (the member constituents), each hosted by group (placement) G, the compact representation is T@G or {T}@G, with the all_equal bit set or unset, respectively. For example, {int32}@CLIENTS is a collection of potentially distinct integers, one per client; {<x=float32, y=float32>*}@CLIENTS is a federated dataset; and <weights=float32[10,5], bias=float32[5]>@SERVER is a named tuple of weight and bias tensors at the server, where omitting the curly braces indicates that the all_equal bit is set.
```python
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
# '{float32}@CLIENTS'
```
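Setting the all_equal bit drops the curly braces from the compact representation; a small sketch using the all_equal argument of tff.FederatedType:

```python
same_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True)
print(same_float_on_clients)  # 'float32@CLIENTS'
```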
Functions
The language of Federated Core is a \(\lambda\)-calculus. It currently exposes the following programming abstractions in the public API:
- TensorFlow computations (tff.tf_computation). These are pieces of TensorFlow code wrapped as reusable components in TFF using the tff.tf_computation decorator. They always have functional types and, unlike ordinary TensorFlow functions, can take structured parameters or return structured results of sequence type.
```python
import numpy as np

# TensorFlow code cannot run directly inside a tff.federated_computation;
# it must be wrapped with tff.tf_computation as follows:
@tff.tf_computation(tff.SequenceType(tf.int32))
def add_up_integers(x):
  return x.reduce(np.int32(0), lambda x, y: x + y)

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)
# '({float32}@CLIENTS -> float32@SERVER)'
```
- Intrinsics, or federated operators (tff.federated_...). This is the library of functions that makes up the bulk of the FC API, such as tff.federated_sum or tff.federated_broadcast; most of them represent distributed communication operators for use with TFF.
- \(\lambda\) expressions (tff.federated_computation). A \(\lambda\) expression in TFF is the equivalent of a lambda or def in Python; it consists of a parameter name and a body (expression) containing references to that parameter.
```python
# The biggest difference between tff.tf_computation and
# tff.federated_computation is the placement.
@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)
```
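The placement difference shows up directly in the type signatures, which can be printed (the exact string formatting may vary slightly by TFF version):

```python
print(add_half.type_signature)             # '(float32 -> float32)'
print(add_half_on_clients.type_signature)  # '({float32}@CLIENTS -> {float32}@CLIENTS)'
```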
Workflow
A typical federated learning algorithm consists of three main pieces of logic:
- Individual-level TF code fragments, e.g. tf.functions, that can run locally and independently, such as the clients' training code.
- TFF orchestration code that wraps the individual-level tf.functions as tff.tf_computations and stitches them together with federated operators such as tff.federated_broadcast and tff.federated_mean.
- External driver code, such as client selection.
```python
# Data preparation
import nest_asyncio
nest_asyncio.apply()

import tensorflow as tf
import tensorflow_federated as tff

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(x))
    for x in client_ids
]
```
```python
# Model preparation
def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
```
To build our own Federated Learning algorithm, there are four main components:
- A server-to-clients broadcast step
- A local client update step
- A client-to-server upload step
- A server update step
We also have to implement our own initialize and next functions.
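Before building the real algorithm, it helps to see the bare contract that initialize and next must satisfy. Below is a minimal sketch with a trivial integer state, reusing the tf / tff imports above (all names here are illustrative, not part of the algorithm that follows):

```python
@tff.tf_computation
def trivial_init():
  return tf.constant(0)

@tff.federated_computation
def trivial_initialize_fn():
  # Place the initial state at the server.
  return tff.federated_value(trivial_init(), tff.SERVER)

@tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
def trivial_next_fn(state):
  # A real next() would broadcast, update locally, and aggregate;
  # this one just passes the state through.
  return state

trivial_process = tff.templates.IterativeProcess(
    initialize_fn=trivial_initialize_fn, next_fn=trivial_next_fn)
state = trivial_process.initialize()  # 0
state = trivial_process.next(state)
```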
Method_1
Local training
Local training does not involve TFF at all.
```python
# Step 2: local training, returning the updated client model weights.
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  client_weights = model.trainable_variables
  # Clone server_weights (this is exactly the `state` in the previous code)
  # into the local client weights.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Optimization loop.
  for batch in dataset:
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)
    grads = tape.gradient(outputs.loss, client_weights)
    grad_and_vars = zip(grads, client_weights)
    client_optimizer.apply_gradients(grad_and_vars)  # update
  return client_weights
```
The input parameters are model, dataset, server_weights, and client_optimizer. Why so many parameters? Because a tf.function carries no data-placement information at all; the placement part is left to TFF.
Server update
Like the client-side update, the server-side update does not involve TFF.
```python
# Step 4: server update.
@tf.function
def server_update(model, mean_client_weights):
  model_weights = model.trainable_variables
  # Assign the averaged client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights
```
TFF snippet
Now we need TFF to stitch together data with different placements, implementing the two methods of tff.templates.IterativeProcess.
```python
# The initialize method.
@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

@tff.federated_computation
def initialize_fn():
  # tff.federated_value creates a federated value with the given placement,
  # whose member constituent is equal at all locations.
  return tff.federated_value(server_init(), tff.SERVER)
```
```python
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)  # input specification
model_weights_type = server_init.type_signature.result       # output specification
```
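Printing the inferred types is a useful sanity check (a hedged sketch; the exact strings depend on the model and the TFF version):

```python
print(tf_dataset_type)     # e.g. '<float32[?,784],int32[?,1]>*'
print(model_weights_type)  # e.g. '<float32[784,10],float32[10]>'
```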
```python
# These computations take data from multiple sources, so they are decorated
# with tff.tf_computation.
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

# The next function; the state is server_weights.
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Step 1: broadcast.
  server_weights_at_client = tff.federated_broadcast(server_weights)
  # Step 2: local update.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Step 3: upload.
  mean_client_weights = tff.federated_mean(client_weights)
  # Step 4: server update.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)
  return server_weights

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn)
```
```python
# Centralized evaluation on the full EMNIST test set.
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

server_state = federated_algorithm.initialize()
evaluate(server_state)
for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
```
Method_2
In the second method, an optimizer from tff.learning.optimizers supersedes the Keras optimizer used above; such an optimizer exposes initialize(<TensorSpec>) and next functions.
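A functional optimizer keeps its state explicit instead of hiding it in mutable variables. A minimal sketch of the contract, assuming a single scalar weight:

```python
optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.1)
spec = tf.TensorSpec(shape=[], dtype=tf.float32)
state = optimizer.initialize(spec)  # explicit optimizer state
weights = tf.constant(1.0)
grads = tf.constant(0.5)
# next() returns the new optimizer state and the updated weights.
state, weights = optimizer.next(state, weights, grads)  # weights -> 0.95
```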
TF snippet
```python
import attr

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  client_weights = model.trainable_variables
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Initialize the client optimizer state from the weight specs.
  trainable_tensor_specs = tf.nest.map_structure(
      lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)
    grads = tape.gradient(outputs.loss, client_weights)
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, client_weights, grads)
    tf.nest.map_structure(lambda a, b: a.assign(b),
                          client_weights, updated_weights)
  # Return the model delta (the cumulative update), not the weights themselves.
  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)

# Container collecting the server weights and the server optimizer state.
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  trainable_weights = attr.ib()
  optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  # Descend along the negated averaged delta.
  negative_weights_delta = tf.nest.map_structure(
      lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state, updated_weights = server_optimizer.next(
      server_state.optimizer_state,
      server_state.trainable_weights,
      negative_weights_delta)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=updated_weights,
      optimizer_state=new_optimizer_state)
```
TFF snippet
```python
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)

@tff.tf_computation
def server_init():
  model = model_fn()
  trainable_tensor_specs = tf.nest.map_structure(
      lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
  return ServerState(
      trainable_weights=model.trainable_variables,
      optimizer_state=optimizer_state)

@tff.federated_computation
def server_init_tff():
  return tff.federated_value(server_init(), tff.SERVER)

server_state_type = server_init.type_signature.result
trainable_weights_type = server_state_type.trainable_weights

@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)

federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  # Step 1: broadcast the current server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
  # Step 2: clients compute model deltas locally.
  model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Step 3: average the deltas on the server.
  mean_model_delta = tff.federated_mean(model_deltas)
  # Step 4: the server applies the averaged delta with its own optimizer.
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
  return server_state

fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
```
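Driving this process mirrors the loop from Method 1, except that the state is now a ServerState rather than bare weights. A sketch reusing the evaluate helper from above:

```python
server_state = fedavg_process.initialize()
for _ in range(15):
  server_state = fedavg_process.next(server_state, federated_train_data)
# evaluate() expects the raw weights, so unwrap them from the state.
evaluate(server_state.trainable_weights)
```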
Summary
The process of customizing our own tff.templates.IterativeProcess class:
- First, ignoring placement constraints, write the TensorFlow code that implements the client update and server update functions. Typically, the inputs of the client update are the model, dataset, server_weights and an optimizer, and the output is either the cumulative gradient (model delta) or the new client trainable variables. The server update's input is simpler: the current server state and the newly aggregated changes; its output is the new server state. Depending on your design, the server state can be just the model's trainable variables or can contain other items as well. Both functions are decorated with tf.function;
- Second, write server_update_fn, client_update_fn and server_init_fn, all decorated with tff.tf_computation. The decoration signals that each function's input parameters live in a single placement. In server_init_fn, the output is the initial state. In client_update_fn, the inputs are the dataset and server_weights (note: server_weights is the copy placed at tff.CLIENTS by tff.federated_broadcast), and it calls the earlier client update function. In server_update_fn, the inputs are server_state and the aggregated changes (note: the changes are aggregated by tff.federated_mean and placed at tff.SERVER), and it calls the earlier server update function;
- Third, create server_init_tff and next_fn, both decorated with tff.federated_computation, to handle the placement issues. server_init_tff places the output of server_init at tff.SERVER via tff.federated_value. next_fn carries out the four steps of the workflow: broadcast, local update, aggregation, and server update.