Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions custom/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/*
!/README.md
!/.gitignore
!/__init__.py
Empty file added custom/__init__.py
Empty file.
74 changes: 39 additions & 35 deletions dev_utils/mongo_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,48 +130,52 @@ def denormalize_files_from_reports(reports):
"""Pull the file info from the FILES_COLL collection in to associated parts of
the reports.
"""
# Make sure we have a list whose objects we can modify in place instead of a mongo
# cursor as returned from mongo_find.
reports = list(reports)
file_dicts = [
file_dict
for file_dict in itertools.chain.from_iterable(collect_file_dicts(report) for report in reports)
if FILE_REF_KEY in file_dict
]
if not file_dicts:
# These are likely partial reports (like for an ajax request of a specific
# part of the report), had a projection applied that does not include any file
# information, or only the old-style of storing file information is present in
# these documents.
return reports

file_refs = {file_dict[FILE_REF_KEY] for file_dict in file_dicts}

file_docs = {}
batch_size = 50
file_ref_iter = iter(file_refs)
while batch := tuple(itertools.islice(file_ref_iter, batch_size)):
# Reduce the size of the $in clause when there are large numbers of file refs by
# making multiple requests, passing batches of refs in.
for file_doc in mongo_find(FILES_COLL, {"_id": {"$in": batch}}, {TASK_IDS_KEY: 0}):
file_docs[file_doc.pop("_id")] = file_doc

for file_dict in file_dicts:
if file_dict[FILE_REF_KEY] not in file_docs:
log.warning("Failed to find %s in %s collection.", FILES_COLL, file_dict[FILE_REF_KEY])
continue
file_doc = file_docs[file_dict.pop(FILE_REF_KEY)]
file_dict.update(file_doc)

return reports
def denormalize_generator(reports_iterable):
# Optimization: Ensure we have an iterator to avoid infinite loops on lists
reports_iter = iter(reports_iterable)
batch_size = 50
while True:
# Grab a batch of reports from the cursor
reports_batch = list(itertools.islice(reports_iter, batch_size))
if not reports_batch:
break

file_dicts = [
file_dict
for file_dict in itertools.chain.from_iterable(collect_file_dicts(report) for report in reports_batch)
if FILE_REF_KEY in file_dict
]

if file_dicts:
file_refs = {file_dict[FILE_REF_KEY] for file_dict in file_dicts}
file_docs = {}
file_ref_batch_size = 50
file_ref_iter = iter(file_refs)
while batch := tuple(itertools.islice(file_ref_iter, file_ref_batch_size)):
# Reduce the size of the $in clause when there are large numbers of file refs by
# making multiple requests, passing batches of refs in.
for file_doc in mongo_find(FILES_COLL, {"_id": {"$in": batch}}, {TASK_IDS_KEY: 0}):
file_docs[file_doc.pop("_id")] = file_doc

for file_dict in file_dicts:
if file_dict[FILE_REF_KEY] not in file_docs:
log.warning("Failed to find %s in %s collection.", FILES_COLL, file_dict[FILE_REF_KEY])
continue
file_doc = file_docs[file_dict.pop(FILE_REF_KEY)]
file_dict.update(file_doc)

yield from reports_batch

return denormalize_generator(reports)


@mongo_hook(mongo_find_one, "analysis")
def denormalize_files(report):
"""Pull the file info from the FILES_COLL collection in to associated parts of
the report.
"""
denormalize_files_from_reports([report])
# Consume the generator so the report is denormalized in-place
list(denormalize_files_from_reports([report]))
return report


Expand Down
66 changes: 50 additions & 16 deletions dev_utils/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,53 @@

def connect_to_mongo() -> MongoClient:
try:
return MongoClient(
host=repconf.mongodb.get("host", "127.0.0.1"),
port=repconf.mongodb.get("port", 27017),
host = repconf.mongodb.get("host", "127.0.0.1")
port = repconf.mongodb.get("port", 27017)
client = MongoClient(
host=host,
port=port,
username=repconf.mongodb.get("username"),
password=repconf.mongodb.get("password"),
authSource=repconf.mongodb.get("authsource", "cuckoo"),
tlsCAFile=repconf.mongodb.get("tlscafile", None),
connect=False,
connect=True, # Force connection now to catch issues
serverSelectionTimeoutMS=5000,
socketTimeoutMS=30000,
)
except (ConnectionFailure, ServerSelectionTimeoutError):
log.error("Cannot connect to MongoDB")
# Ping the server to ensure it's alive
client.admin.command('ping')
log.info("Successfully connected to MongoDB at %s:%s", host, port)
return client
except (ConnectionFailure, ServerSelectionTimeoutError) as e:
log.error("Cannot connect to MongoDB: %s", e)
except Exception as e:
log.warning("Unable to connect to MongoDB database: %s, %s", mdb, e)

# code.interact(local=dict(locals(), **globals()))
# q = results_db.analysis.find({"info.id": 26}, {"memory": 1})
# https://pymongo.readthedocs.io/en/stable/changelog.html

conn = connect_to_mongo()
results_db = conn[mdb]
_client = None
_results_db = None

def get_mongodb():
global _client, _results_db
if _client is None:
_client = connect_to_mongo()
_results_db = _client[mdb]
return _results_db

# For legacy code that expects results_db to be an object
class LegacyDB:
@property
def analysis(self): return get_mongodb().analysis
@property
def calls(self): return get_mongodb().calls
@property
def files(self): return get_mongodb().files
def __getattr__(self, name): return getattr(get_mongodb(), name)

results_db = LegacyDB()

MAX_AUTO_RECONNECT_ATTEMPTS = 5

Expand Down Expand Up @@ -111,7 +138,7 @@ def mongo_insert_one(collection: str, doc):


@graceful_auto_reconnect
def mongo_find(collection: str, query, projection=False, sort=None, limit=None):
def mongo_find(collection: str, query, projection=False, sort=None, limit=None, no_hooks=False):
if sort is None:
sort = [("_id", -1)]

Expand All @@ -122,23 +149,30 @@ def mongo_find(collection: str, query, projection=False, sort=None, limit=None):
find_by = functools.partial(find_by, limit=limit)

result = find_by()
if result:
if result and not no_hooks:
for hook in hooks[mongo_find][collection]:
result = hook(result)
return result


@graceful_auto_reconnect
def mongo_find_one(collection: str, query, projection=False, sort=None):
def mongo_find_one(collection: str, query, projection=False, sort=None, max_time_ms=None, no_hooks=False):
if sort is None:
sort = [("_id", -1)]

kwargs = {"sort": sort}
if max_time_ms:
kwargs["max_time_ms"] = max_time_ms

if projection:
result = getattr(results_db, collection).find_one(query, projection, sort=sort)
result = getattr(results_db, collection).find_one(query, projection, **kwargs)
else:
result = getattr(results_db, collection).find_one(query, sort=sort)
if result:
result = getattr(results_db, collection).find_one(query, **kwargs)

if result and not no_hooks:
for hook in hooks[mongo_find_one][collection]:
result = hook(result)

return result


Expand Down Expand Up @@ -184,7 +218,7 @@ def mongo_find_one_and_update(collection, query, update, projection=None):

@graceful_auto_reconnect
def mongo_drop_database(database: str):
conn.drop_database(database)
get_mongodb().client.drop_database(database)


def mongo_delete_data(task_ids: int | Sequence[int]) -> None:
Expand Down Expand Up @@ -251,7 +285,7 @@ def mongo_delete_calls_by_task_id_in_range(*, range_start: int = 0, range_end: i
def mongo_is_cluster():
# This is only useful at the moment for clean to prevent destruction of cluster database
try:
conn.admin.command("listShards")
get_mongodb().client.admin.command("listShards")
return True
except OperationFailure:
return False
Expand Down
31 changes: 19 additions & 12 deletions utils/gcp_pubsub_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
import logging
import os
import sys
import tempfile
import shutil
import threading
import warnings

sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), ".."))

from lib.cuckoo.common.config import Config
from lib.cuckoo.common.constants import CUCKOO_ROOT
from lib.cuckoo.common.gcp import download_from_gcs
from lib.cuckoo.common.path_utils import path_exists
from lib.cuckoo.common.utils import store_temp_file
from lib.cuckoo.core.database import Database, init_database
from lib.cuckoo.core.startup import check_user_permissions
from utils.submit import submit_file
Expand All @@ -25,6 +27,14 @@
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s: %(message)s")
log = logging.getLogger("gcp_pubsub_service")

warnings.filterwarnings(
"ignore",
message="You are using a non-supported Python version",
category=FutureWarning,
module="google\\.api_core",
)


class GCPPubSubService:
def __init__(self):
self.gcp_cfg = Config("gcp")
Expand Down Expand Up @@ -57,9 +67,9 @@ def __init__(self):
self.db = Database()

def process_message(self, message):
local_path = None
try:
payload = json.loads(message.data.decode("utf-8"))
log.info("Received payload: %s", payload.get("uuid"))

sample_hash = payload.get("sample_hash")
gcs_uri = payload.get("gcs_uri")
Expand Down Expand Up @@ -90,17 +100,20 @@ def process_message(self, message):
# Check if sample exists locally
sample_hash = os.path.basename(sample_hash)
local_path = os.path.join(CUCKOO_ROOT, "storage", "binaries", sample_hash)
is_temp = False

if not path_exists(local_path):
log.info("Sample %s not found locally, fetching from GCS: %s", sample_hash, gcs_uri)
fd, temp_path = tempfile.mkstemp()
os.close(fd)
# Create a temporary path using store_temp_file with empty content
temp_path = store_temp_file(b"", sample_name)
if isinstance(temp_path, bytes):
temp_path = temp_path.decode()

if download_from_gcs(gcs_uri, temp_path):
local_path = temp_path
is_temp = True
else:
log.error("Failed to download sample from GCS")
if os.path.exists(os.path.dirname(temp_path)):
shutil.rmtree(os.path.dirname(temp_path))
message.nack()
return

Expand All @@ -123,12 +136,6 @@ def process_message(self, message):
except Exception as e:
log.error("Failed to add task to database: %s", e)
message.nack()
finally:
if is_temp and path_exists(local_path):
try:
os.unlink(local_path)
except Exception as e:
log.warning("Failed to delete temp file %s for task, %s: %s", local_path, payload.get("uuid"), e)

except Exception as e:
log.error("Error processing message: %s", e)
Expand Down
Loading
Loading