Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions baselines/diabetic_retinopathy_detection/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def checkpoint_trained_model(
accum_train_time=checkpoint_data.accumulated_train_time),
)
if checkpoint_data.fixed_model_states is not None:
tree["states"] = checkpoint_data.fixed_model_states
tree["states"] = checkpoint_data.fixed_model_states # pyrefly: ignore[bad-assignment]
save_checkpoint(tree, path, step_for_copy)


Expand All @@ -216,7 +216,7 @@ def _flatten_jax_params_dict(d: Params, parent_key: str = "",
for k, v in d.items():
path = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.Mapping):
items.extend(_flatten_jax_params_dict(v, path, sep=sep).items())
items.extend(_flatten_jax_params_dict(v, path, sep=sep).items()) # pyrefly: ignore[bad-argument-type]
else:
items.append((path, v))

Expand Down Expand Up @@ -411,7 +411,7 @@ def maybe_load_checkpoint(train_loop_rngs: jnp.ndarray,

checkpoint_tree = {"opt": init_optimizer, "extra": checkpoint_extra}
if init_fixed_model_states is not None:
checkpoint_tree["states"] = init_fixed_model_states
checkpoint_tree["states"] = init_fixed_model_states # pyrefly: ignore[bad-assignment]
checkpoint = load_checkpoint(checkpoint_tree, resume_checkpoint_path)
optimizer, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
fixed_model_states = checkpoint.get("states", None)
Expand Down Expand Up @@ -442,7 +442,7 @@ def maybe_load_checkpoint(train_loop_rngs: jnp.ndarray,
return CheckpointData(
optimizer=optimizer,
fixed_model_states=fixed_model_states,
train_loop_rngs=checkpoint_extra["rngs_loop"],
train_loop_rngs=checkpoint_extra["rngs_loop"], # pyrefly: ignore[bad-argument-type]
accumulated_train_time=checkpoint_extra["accum_train_time"])


Expand Down
2 changes: 1 addition & 1 deletion baselines/diabetic_retinopathy_detection/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def get_data(
dataset_builder = _get_dataset_builder(dataset, data_dir)

if rng is not None:
rng = jax.random.fold_in(rng, process_index) # Derive RNG for this process.
rng = jax.random.fold_in(rng, process_index) # Derive RNG for this process. # pyrefly: ignore[bad-argument-type]

process_split = _get_process_split(
split,
Expand Down
4 changes: 2 additions & 2 deletions baselines/diabetic_retinopathy_detection/preprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __call__(self, features: Features) -> Features:

# Perform binarization using given threshold
labels = features["label"]
labels = tf.cast(labels > highest_negative_class, tf.int32)
labels = tf.cast(labels > highest_negative_class, tf.int32) # pyrefly: ignore[unsupported-operation]
features[self.key_result or self.key] = decoded_image
features["labels"] = labels
del features["label"]
Expand Down Expand Up @@ -413,7 +413,7 @@ def __call__(self, features: Features) -> Features:
# than using tf.one_hot followed by tf.reduce_max; we tested.
labels = features[self.key]
if labels.shape.rank > 0 and self.multi: # pytype: disable=attribute-error
x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]),
x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]), # pyrefly: ignore[bad-index]
(self.depth,))
x = tf.clip_by_value(x, 0, 1) * (self.on - self.off) + self.off
else:
Expand Down
6 changes: 3 additions & 3 deletions baselines/diabetic_retinopathy_detection/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def hms(s):
return f"{h:.0f}h{m:.0f}m" # Seconds intentionally omitted.

# Progress note with "global" full-program average timings
dt = now - self.start_time # Time since process start.
dt = now - self.start_time # Time since process start. # pyrefly: ignore[unsupported-operation]
steps_done = step - self.first_step
steps_todo = self.total_steps - step
self.note = f"Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]"
Expand All @@ -270,7 +270,7 @@ def hms(s):
timing_measurements = {}

# Measurement with micro-timings of current training steps speed.
dt = now - self.prev_time - self.paused_time # Time between ticks.
dt = now - self.prev_time - self.paused_time # Time between ticks. # pyrefly: ignore[unsupported-operation]
ds = step - self.prev_step # Steps between ticks.
ncores = jax.device_count() # Global device count.
timing_measurements["img/sec/core"] = self.global_bs * ds / dt / ncores
Expand All @@ -294,5 +294,5 @@ def pause(self):

def resume(self):
"""Resumes the time measurement."""
self.paused_time += time.time() - self.pause_start
self.paused_time += time.time() - self.pause_start # pyrefly: ignore[bad-assignment, unsupported-operation]
self.pause_start = None
6 changes: 3 additions & 3 deletions baselines/drug_cardiotoxicity/augmentation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _perturb_edges_helper(
perturbation_values_for_pair_mask = drop_values_for_pair_mask
else:
edges_to_perturb = tf.concat(
(idx_edges_to_drop, idx_bidirectional_edges_to_add), axis=0)
(idx_edges_to_drop, idx_bidirectional_edges_to_add), axis=0) # pyrefly: ignore[unbound-name]
add_values_for_pair_mask = tf.ones(
tf.shape(idx_bidirectional_edges_to_add)[0])
perturbation_values_for_pair_mask = tf.concat(
Expand All @@ -344,11 +344,11 @@ def _perturb_edges_helper(
else:
if self.initialize_edge_features_randomly:
add_values_for_pairs = tf.random.uniform(
(tf.shape(idx_bidirectional_edges_to_add)[0],
(tf.shape(idx_bidirectional_edges_to_add)[0], # pyrefly: ignore[unbound-name]
tf.shape(pairs)[-1]))
else:
add_values_for_pairs = tf.concat(
(features_of_edges_to_drop, features_of_edges_to_drop), axis=0)
(features_of_edges_to_drop, features_of_edges_to_drop), axis=0) # pyrefly: ignore[unbound-name]
perturbation_values_for_pairs = tf.concat(
(drop_values_for_pairs, add_values_for_pairs), axis=0)

Expand Down
12 changes: 6 additions & 6 deletions baselines/drug_cardiotoxicity/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def step_fn(inputs):
negative_log_likelihood = tf.reduce_mean(
tf.keras.losses.categorical_crossentropy(labels, probs))

metrics[f'{dataset_name}/negative_log_likelihood'].update_state(
metrics[f'{dataset_name}/negative_log_likelihood'].update_state( # pyrefly: ignore[missing-argument]
negative_log_likelihood)
metrics[f'{dataset_name}/accuracy'].update_state(labels, probs)
metrics[f'{dataset_name}/roc_auc'].update_state(
Expand All @@ -275,8 +275,8 @@ def step_fn(inputs):
logging.info('Starting to run epoch: %s', epoch)
train_step(train_iterator)

current_step = (epoch + 1) * params.steps_per_epoch
max_steps = params.steps_per_epoch * params.num_epochs
current_step = (epoch + 1) * params.steps_per_epoch # pyrefly: ignore[unsupported-operation]
max_steps = params.steps_per_epoch * params.num_epochs # pyrefly: ignore[unsupported-operation]
time_elapsed = time.time() - start_time
steps_per_sec = float(current_step) / time_elapsed
eta_seconds = (max_steps - current_step) / steps_per_sec
Expand Down Expand Up @@ -325,11 +325,11 @@ def main(argv: Sequence[str]):
strategy = tf.distribute.MirroredStrategy()

train_dataset, steps_per_epoch = utils.load_dataset(FLAGS.data_dir,
tfds.Split.TRAIN,
tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
FLAGS.batch_size)

eval_identifiers = ['tune', 'test1', 'test2']
splits = [tfds.Split.VALIDATION, tfds.Split.TEST, tfds.Split('test2')]
splits = [tfds.Split.VALIDATION, tfds.Split.TEST, tfds.Split('test2')] # pyrefly: ignore[missing-attribute]
eval_datasets, steps_per_eval = utils.load_eval_datasets(
eval_identifiers, splits, FLAGS.data_dir, FLAGS.batch_size)

Expand Down Expand Up @@ -380,7 +380,7 @@ def main(argv: Sequence[str]):
strategy=strategy,
summary_writer=summary_writer,
loss_type=FLAGS.loss_type,
graph_augmenter=graph_augmenter)
graph_augmenter=graph_augmenter) # pyrefly: ignore[bad-argument-type]


if __name__ == '__main__':
Expand Down
14 changes: 7 additions & 7 deletions baselines/drug_cardiotoxicity/sngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def per_replica_train_step_fn(inputs):
sample_weights = 1

with tf.GradientTape() as tape:
probs = model(features, training=True)
probs = model(features, training=True) # pyrefly: ignore[not-callable]
negative_log_likelihood = tf.reduce_mean(
tf.keras.losses.categorical_crossentropy(labels, probs) *
sample_weights)
Expand Down Expand Up @@ -215,11 +215,11 @@ def per_replica_eval_step_fn(inputs, dataset_name):
else:
features, labels = inputs

probs = model(features, training=False)
probs = model(features, training=False) # pyrefly: ignore[not-callable]
negative_log_likelihood = tf.reduce_mean(
tf.keras.losses.categorical_crossentropy(labels, probs))

metrics[f'{dataset_name}/negative_log_likelihood'].update_state(
metrics[f'{dataset_name}/negative_log_likelihood'].update_state( # pyrefly: ignore[missing-argument]
negative_log_likelihood)
metrics[f'{dataset_name}/accuracy'].update_state(labels, probs)
metrics[f'{dataset_name}/roc_auc'].update_state(labels[:, 1], probs[:, 1])
Expand Down Expand Up @@ -254,8 +254,8 @@ def distributed_eval_step(iterator, dataset_name, num_steps):
logging.info('Starting to run epoch: %s', epoch)
distributed_train_step(train_iterator)

current_step = (epoch + 1) * params.steps_per_epoch
max_steps = params.steps_per_epoch * params.num_epochs
current_step = (epoch + 1) * params.steps_per_epoch # pyrefly: ignore[unsupported-operation]
max_steps = params.steps_per_epoch * params.num_epochs # pyrefly: ignore[unsupported-operation]
time_elapsed = time.time() - start_time
steps_per_sec = float(current_step) / time_elapsed
eta_seconds = (max_steps - current_step) / steps_per_sec
Expand Down Expand Up @@ -305,11 +305,11 @@ def main(argv: Sequence[str]):
strategy = tf.distribute.MirroredStrategy()

train_dataset, steps_per_epoch = utils.load_dataset(FLAGS.data_dir,
tfds.Split.TRAIN,
tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
FLAGS.batch_size)

eval_identifiers = ['tune', 'test1', 'test2']
splits = [tfds.Split.VALIDATION, tfds.Split.TEST, tfds.Split('test2')]
splits = [tfds.Split.VALIDATION, tfds.Split.TEST, tfds.Split('test2')] # pyrefly: ignore[missing-attribute]
eval_datasets, steps_per_eval = utils.load_eval_datasets(
eval_identifiers, splits, FLAGS.data_dir, FLAGS.batch_size)

Expand Down
4 changes: 2 additions & 2 deletions baselines/drug_cardiotoxicity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def get_metric_result_value(metric):
def load_dataset(data_dir, split, batch_size):
"""Loads a single dataset with specific split."""
known_splits = [
tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST,
tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST, # pyrefly: ignore[missing-attribute]
tfds.Split('test2')
]
if split in known_splits:
is_training = split == tfds.Split.TRAIN
is_training = split == tfds.Split.TRAIN # pyrefly: ignore[missing-attribute]
else:
raise ValueError(
'Received ambiguous split {}, must set is_training for splits other '
Expand Down
32 changes: 16 additions & 16 deletions baselines/privileged_information/cifar_pi/distill_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def main(argv):
steps_per_validation = validation_builder.num_examples // batch_size
clean_test_builder = ub.datasets.get(
'cifar10' if FLAGS.dataset in ['cifar10n', 'cifar10h'] else 'cifar100',
split=tfds.Split.TEST
if FLAGS.dataset != 'cifar10h' else tfds.Split.TRAIN,
split=tfds.Split.TEST # pyrefly: ignore[missing-attribute]
if FLAGS.dataset != 'cifar10h' else tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
data_dir=data_dir,
drop_remainder=FLAGS.drop_remainder_for_eval,
is_training=False)
Expand Down Expand Up @@ -328,7 +328,7 @@ def main(argv):
f'{FLAGS.dataset}_corrupted',
corruption_type=corruption_type,
severity=severity,
split=tfds.Split.TEST,
split=tfds.Split.TEST, # pyrefly: ignore[missing-attribute]
data_dir=data_dir).load(batch_size=batch_size)
test_datasets[f'{corruption_type}_{severity}'] = (
strategy.experimental_distribute_dataset(dataset))
Expand Down Expand Up @@ -375,7 +375,7 @@ def annotator_labels_encoding(example):
'random_pi': lambda e: e['pi_features']['random_pi'],
}
privileged_information_fn = pi_utils.get_privileged_information_fn(
pi_subset=FLAGS.pi_subset.split(','), encoding_fn_dict=encoding_fn_dict)
pi_subset=FLAGS.pi_subset.split(','), encoding_fn_dict=encoding_fn_dict) # pyrefly: ignore[bad-argument-type]

pi_shape = privileged_information_fn(dummy_example).shape[1:]

Expand Down Expand Up @@ -470,12 +470,12 @@ def create_optimizer(epochs):
rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
})
if FLAGS.eval_on_ood:
ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names)
ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names) # pyrefly: ignore[unbound-name]
metrics.update(ood_metrics)
if FLAGS.corruptions_interval > 0:
corrupt_metrics = {}
for intensity in range(1, 6):
for corruption in corruption_types:
for corruption in corruption_types: # pyrefly: ignore[unbound-name]
dataset_name = '{0}_{1}'.format(corruption, intensity)
corrupt_metrics[f'distill/test/nll_{dataset_name}'] = (
tf.keras.metrics.Mean())
Expand Down Expand Up @@ -590,7 +590,7 @@ def update_train_metrics(labels,

sample_weight = None
if noise_split == '_clean':
sample_weight = 1. - noisy_idx
sample_weight = 1. - noisy_idx # pyrefly: ignore[unsupported-operation]
elif noise_split == '_noisy':
sample_weight = noisy_idx
metrics[f'{training_phase}/train/ece{noise_split}'].add_batch(
Expand All @@ -613,7 +613,7 @@ def step_fn(inputs):
noisy_idx = pi_utils.find_noisy_annotators(inputs)

with tf.GradientTape() as tape:
logits = teacher_model((images, privileged_information), training=True)
logits = teacher_model((images, privileged_information), training=True) # pyrefly: ignore[not-callable]

# Flatten the annotator axis.
logits = pi_utils.flatten_annotator_axis(logits)
Expand Down Expand Up @@ -668,9 +668,9 @@ def step_fn(inputs):
noisy_idx = pi_utils.find_noisy_annotators(inputs)

with tf.GradientTape() as tape:
teacher_logits = teacher_model((images, privileged_information),
teacher_logits = teacher_model((images, privileged_information), # pyrefly: ignore[not-callable]
training=False)
logits = distill_model(images, training=True)
logits = distill_model(images, training=True) # pyrefly: ignore[not-callable]

# teacher_logits: (batch_size, num_annotators, num_classes) ->
# (batch_size * num_annotators, num_classes)
Expand Down Expand Up @@ -755,7 +755,7 @@ def step_fn(inputs):
else:
labels = inputs['clean_labels']

logits = distill_model(images, training=False)
logits = distill_model(images, training=False) # pyrefly: ignore[not-callable]
# pytype: disable=attribute-error
logits = pi_utils.repeat_across_annotators(
logits,
Expand All @@ -764,7 +764,7 @@ def step_fn(inputs):
# pytype: enable=attribute-error
logits = pi_utils.flatten_annotator_axis(logits)
if use_annotator_labels:
logits = tf.gather(logits, non_empty_indices)
logits = tf.gather(logits, non_empty_indices) # pyrefly: ignore[unbound-name]
probs = tf.nn.softmax(logits)

negative_log_likelihood = tf.reduce_mean(
Expand All @@ -789,7 +789,7 @@ def step_fn(inputs):
images = inputs['features']
labels = inputs['labels']

logits = distill_model(images, training=False)
logits = distill_model(images, training=False) # pyrefly: ignore[not-callable]
probs = tf.nn.softmax(logits)

negative_log_likelihood = tf.reduce_mean(
Expand Down Expand Up @@ -941,19 +941,19 @@ def write_metrics_to_summary(total_results, epoch):
logging.info('Done with testing on %s', dataset_name)

if FLAGS.eval_on_ood:
for ood_dataset_name, ood_dataset in ood_datasets.items():
for ood_dataset_name, ood_dataset in ood_datasets.items(): # pyrefly: ignore[unbound-name]
ood_iterator = iter(ood_dataset)
logging.info('Calculating OOD on dataset %s', ood_dataset_name)
logging.info('Running OOD eval at epoch: %s', epoch)
test_step(ood_iterator, 'test', ood_dataset_name,
steps_per_ood[ood_dataset_name])
steps_per_ood[ood_dataset_name]) # pyrefly: ignore[unbound-name]

logging.info('Done with OOD eval on %s', ood_dataset_name)

corrupt_results = {}
if (FLAGS.corruptions_interval > 0 and
(epoch + 1) % FLAGS.corruptions_interval == 0):
corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics, # pyrefly: ignore[unbound-name]
corruption_types)

logging.info('Distillation Loss: %.4f, Accuracy: %.2f%%',
Expand Down
Loading
Loading