Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rust_snuba/src/consumer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use chrono::{DateTime, Utc};
Expand Down Expand Up @@ -270,6 +270,7 @@ pub fn consumer_impl(
join_timeout_ms,
health_check: health_check.to_string(),
use_row_binary,
assigned_partitions: Arc::new(Mutex::new(Vec::new())),
};

let processor = StreamProcessor::with_kafka(config, factory, topic, dlq_policy);
Expand Down
9 changes: 8 additions & 1 deletion rust_snuba/src/factory_v2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use sentry::{Hub, SentryFutureExt};
Expand Down Expand Up @@ -62,6 +62,7 @@ pub struct ConsumerStrategyFactoryV2 {
pub join_timeout_ms: Option<u64>,
pub health_check: String,
pub use_row_binary: bool,
pub assigned_partitions: Arc<Mutex<Vec<u16>>>,
}

impl ProcessingStrategyFactory<KafkaPayload> for ConsumerStrategyFactoryV2 {
Expand All @@ -82,6 +83,8 @@ impl ProcessingStrategyFactory<KafkaPayload> for ConsumerStrategyFactoryV2 {
Some(min) => set_global_tag("min_partition".to_owned(), min.to_string()),
None => set_global_tag("min_partition".to_owned(), "none".to_owned()),
}

*self.assigned_partitions.lock().unwrap() = assigned_partitions;
}

fn create(&self) -> Box<dyn ProcessingStrategy<KafkaPayload>> {
Expand All @@ -106,6 +109,7 @@ impl ProcessingStrategyFactory<KafkaPayload> for ConsumerStrategyFactoryV2 {

let next_step: Box<dyn ProcessingStrategy<BytesInsertBatch<()>>> =
if let Some((ref producer, destination)) = self.commit_log_producer {
let partitions = self.assigned_partitions.lock().unwrap().clone();
Box::new(ProduceCommitLog::new(
next_step,
producer.clone(),
Expand All @@ -114,6 +118,7 @@ impl ProcessingStrategyFactory<KafkaPayload> for ConsumerStrategyFactoryV2 {
self.physical_consumer_group.clone(),
&self.commitlog_concurrency,
false,
partitions,
))
} else {
Box::new(next_step)
Expand Down Expand Up @@ -290,6 +295,7 @@ impl ConsumerStrategyFactoryV2 {

let next_step: Box<dyn ProcessingStrategy<BytesInsertBatch<()>>> =
if let Some((ref producer, destination)) = self.commit_log_producer {
let partitions = self.assigned_partitions.lock().unwrap().clone();
Box::new(ProduceCommitLog::new(
next_step,
producer.clone(),
Expand All @@ -298,6 +304,7 @@ impl ConsumerStrategyFactoryV2 {
self.physical_consumer_group.clone(),
&self.commitlog_concurrency,
false,
partitions,
))
} else {
Box::new(next_step)
Expand Down
200 changes: 198 additions & 2 deletions rust_snuba/src/strategies/commit_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use sentry_arroyo::processing::strategies::{
};
use sentry_arroyo::types::{Message, Topic, TopicOrPartition};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::str;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use thiserror::Error;

Expand Down Expand Up @@ -72,6 +73,8 @@ struct ProduceMessage {
topic: Topic,
consumer_group: String,
skip_produce: bool,
assigned_partitions: Vec<u16>,
last_produced_offsets: Arc<Mutex<BTreeMap<u16, (u64, DateTime<Utc>)>>>,
}

impl ProduceMessage {
Expand All @@ -81,13 +84,16 @@ impl ProduceMessage {
topic: Topic,
consumer_group: String,
skip_produce: bool,
assigned_partitions: Vec<u16>,
) -> Self {
ProduceMessage {
producer,
destination,
topic,
consumer_group,
skip_produce,
assigned_partitions,
last_produced_offsets: Arc::new(Mutex::new(BTreeMap::new())),
}
}
}
Expand All @@ -102,6 +108,8 @@ impl TaskRunner<BytesInsertBatch<()>, BytesInsertBatch<()>, anyhow::Error> for P
let topic = self.topic;
let skip_produce = self.skip_produce;
let consumer_group = self.consumer_group.clone();
let assigned_partitions = self.assigned_partitions.clone();
let last_produced_offsets = self.last_produced_offsets.clone();

let commit_log_offsets = message.payload().commit_log_offsets().clone();

Expand All @@ -110,6 +118,9 @@ impl TaskRunner<BytesInsertBatch<()>, BytesInsertBatch<()>, anyhow::Error> for P
return Ok(message);
}

let partitions_in_batch: std::collections::BTreeSet<u16> =
commit_log_offsets.0.keys().copied().collect();

for (partition, mut entry) in commit_log_offsets.0 {
entry.received_p99.sort();
let received_p99 = entry
Expand All @@ -132,6 +143,37 @@ impl TaskRunner<BytesInsertBatch<()>, BytesInsertBatch<()>, anyhow::Error> for P
tracing::error!(error, "Error producing message");
return Err(RunTaskError::RetryableError);
}

// Update last produced offset for this partition
last_produced_offsets
.lock()
.unwrap()
.insert(partition, (entry.offset, entry.orig_message_ts));
}

// Produce heartbeat entries for idle assigned partitions
let offsets = last_produced_offsets.lock().unwrap().clone();
for &partition in &assigned_partitions {
if !partitions_in_batch.contains(&partition) {
if let Some(&(offset, orig_message_ts)) = offsets.get(&partition) {
let commit = Commit {
topic: topic.to_string(),
partition,
group: consumer_group.clone(),
orig_message_ts,
offset,
received_p99: None,
};

let payload = commit.try_into().unwrap();

if let Err(err) = producer.produce(&destination, payload) {
let error: &dyn std::error::Error = &err;
tracing::error!(error, "Error producing heartbeat message");
return Err(RunTaskError::RetryableError);
}
}
}
}

Ok(message)
Expand All @@ -155,10 +197,18 @@ where
consumer_group: String,
concurrency: &ConcurrencyConfig,
skip_produce: bool,
assigned_partitions: Vec<u16>,
) -> Self {
let inner = RunTaskInThreads::new(
next_step,
ProduceMessage::new(producer, destination, topic, consumer_group, skip_produce),
ProduceMessage::new(
producer,
destination,
topic,
consumer_group,
skip_produce,
assigned_partitions,
),
concurrency,
Some("produce_commit_log"),
);
Expand Down Expand Up @@ -327,6 +377,7 @@ mod tests {
"group1".to_string(),
&concurrency,
false,
vec![0, 1],
);

for payload in payloads {
Expand All @@ -345,4 +396,149 @@ mod tests {
assert_eq!(produced[1].0, "test:0:group1");
assert_eq!(produced[2].0, "test:1:group1");
}

#[test]
fn produce_commit_log_heartbeat_for_idle_partitions() {
// Assigned partitions 0-3, but only partitions 0 and 1 have data.
// After the first batch establishes offsets for partitions 2 and 3,
// a second batch with only partitions 0 and 1 should still produce
// heartbeat entries for partitions 2 and 3.
let produced_payloads = Arc::new(Mutex::new(Vec::new()));

struct MockProducer {
pub payloads: Arc<Mutex<Vec<(String, KafkaPayload)>>>,
}

impl Producer<KafkaPayload> for MockProducer {
fn produce(
&self,
topic: &TopicOrPartition,
payload: KafkaPayload,
) -> Result<(), ProducerError> {
assert_eq!(topic.topic().as_str(), "test-commitlog");
self.payloads.lock().unwrap().push((
str::from_utf8(payload.key().unwrap()).unwrap().to_owned(),
payload,
));
Ok(())
}
}

// Batch 1: all 4 partitions have data (establishes last known offsets)
let batch1 = BytesInsertBatch::from_rows(())
.with_message_timestamp(Utc::now())
.with_commit_log_offsets(CommitLogOffsets(BTreeMap::from([
(
0,
CommitLogEntry {
offset: 100,
orig_message_ts: Utc::now(),
received_p99: Vec::new(),
},
),
(
1,
CommitLogEntry {
offset: 200,
orig_message_ts: Utc::now(),
received_p99: Vec::new(),
},
),
(
2,
CommitLogEntry {
offset: 300,
orig_message_ts: Utc::now(),
received_p99: Vec::new(),
},
),
(
3,
CommitLogEntry {
offset: 400,
orig_message_ts: Utc::now(),
received_p99: Vec::new(),
},
),
])));

// Batch 2: only partitions 0 and 1 have data
let batch2 = BytesInsertBatch::from_rows(())
.with_message_timestamp(Utc::now())
.with_commit_log_offsets(CommitLogOffsets(BTreeMap::from([
(
0,
CommitLogEntry {
offset: 500,
orig_message_ts: Utc::now(),
received_p99: Vec::new(),
},
),
(
1,
CommitLogEntry {
offset: 600,
orig_message_ts: Utc::now(),
received_p99: Vec::new(),
},
),
])));

let producer = MockProducer {
payloads: produced_payloads.clone(),
};

let next_step = TestStrategy::new();

let concurrency = ConcurrencyConfig::new(1);
let mut strategy = ProduceCommitLog::new(
next_step,
Arc::new(producer),
Topic::new("test-commitlog"),
Topic::new("test"),
"group1".to_string(),
&concurrency,
false,
vec![0, 1, 2, 3], // assigned partitions 0-3
);

// Submit and process batch 1
strategy
.submit(Message::new_any_message(batch1, BTreeMap::new()))
.unwrap();
strategy.poll().unwrap();

// Submit and process batch 2
strategy
.submit(Message::new_any_message(batch2, BTreeMap::new()))
.unwrap();
strategy.poll().unwrap();

strategy.join(None).unwrap();

let produced = produced_payloads.lock().unwrap();

// Batch 1: 4 entries (all partitions have data, no heartbeats needed)
// Batch 2: 2 entries (partitions 0,1 data) + 2 heartbeats (partitions 2,3)
assert_eq!(produced.len(), 8);

// Batch 1 entries
assert_eq!(produced[0].0, "test:0:group1");
assert_eq!(produced[1].0, "test:1:group1");
assert_eq!(produced[2].0, "test:2:group1");
assert_eq!(produced[3].0, "test:3:group1");

// Batch 2: partitions 0 and 1 (data), then heartbeats for 2 and 3
assert_eq!(produced[4].0, "test:0:group1");
assert_eq!(produced[5].0, "test:1:group1");
assert_eq!(produced[6].0, "test:2:group1");
assert_eq!(produced[7].0, "test:3:group1");

// Verify heartbeat offsets match the last known offsets from batch 1
let heartbeat_2: Commit = produced[6].1.clone().try_into().unwrap();
assert_eq!(heartbeat_2.offset, 300);

let heartbeat_3: Commit = produced[7].1.clone().try_into().unwrap();
assert_eq!(heartbeat_3.offset, 400);
}
}
Loading
Loading