# NOTE(review): removed a stray scrape artifact ("????" / "Your IP : 3.22.234.133")
# that preceded the module docstring and made the file unparseable.
"""
Goal:
Download ip lists from server and save them into sqlite
(or somewhere else).
Data is downloaded via http api.
- sync_point - data, returned by server, and should be send in next request.
It allows:
- Serialization.
It can contain timestamp_from_s/timestamp_to_s fields, which does not
allow to skip any record (together with transactional http nature).
- Big data chunks handling.
If server will detect that chunk to be send is too big,
it will interrupt request. sync_point argument will allow to handle this
situation smoothly without any knowledge about it on agent side.
- purpose - ip lists with different purposes should be handled differently
(inserted into different ipsets, etc.)
They also can be downloaded by different modules.
- checksum - data to monitor agent's state. If server will detect that agent's
state is not correct (length of specific ip list does not match) it will
force ip list re-download from beginning.
That is why we should not handle any error - in case of incorrect agent's
state server will detect error and fix state.
"""
import asyncio
import ipaddress
import json
import os
import signal
import subprocess
from contextlib import suppress
from dataclasses import dataclass
from logging import getLogger
from typing import Optional, Sequence, Tuple
from defence360agent.contracts.license import LicenseCLN
from defence360agent.contracts.plugins import MessageSource
from defence360agent.internals.global_scope import g
from defence360agent.model import instance
from defence360agent.utils import (
recurring_check,
Scope,
split_for_chunk,
retry_on,
)
from defence360agent.utils.benchmark import Benchmark
from im360.api.server.remote_iplist import RemoteIPListAPI
from im360.contracts.config import Protector
from im360.internals.core.ipset.libipset import IPSetError
from im360.internals.core.ipset.sync import (
IPSetSyncIPListRecords,
IPSetSyncIPListPurpose,
)
from im360.model.firewall import IPListPurpose, IPListRecord
from im360.utils.net import pack_ip_network
from im360.utils.validate import NumericIPVersion
from peewee import SQL, fn
logger = getLogger(__name__)
# Seconds to wait before retrying after a sync error (env-overridable).
ERROR_TIMEOUT = int(
    os.environ.get("IMUNIFY360_REMOTE_IPLIST_ERROR_TIMEOUT", 600)
)
# No cap on consecutive errors: the recurring check retries forever.
ERROR_LIMIT = float("inf")
# Rows per INSERT batch when writing ip records to sqlite.
DB_BATCH_SIZE = 200
# IPs per batch when adding/removing entries in an ipset.
IPSET_BATCH_SIZE = 10000
# On-disk location of the serialized sync_point from the last server response.
SYNC_POINT_FILE = "/var/imunify360/remote_iplist_sync_point.json"
# ip list purposes this agent synchronizes with the server.
PURPOSES = ["splashscreen", "captcha", "drop", "white"]
# Sentinel value of the "unblock" field: drop every record of the iplist.
UNBLOCK_ALL = "ALL"
@dataclass
class IPSetDataToUpdate:
    """
    Used to keep result data to update ipsets,
    to separate database and ipsets processing.
    """

    # id of the iplist whose ipsets are being updated
    # (presumably a str id from the server response — TODO confirm).
    # NOTE: annotated so it is a real dataclass field; the original
    # un-annotated `iplist_id = None` was a plain class attribute,
    # silently excluded from __init__/__repr__/__eq__ unlike every
    # other attribute here.
    iplist_id: Optional[str] = None
    # ip versions to create iplist_id ipsets
    iplist_id_to_create: Sequence[int] = tuple()
    # ip versions to flush iplist_id ipsets
    iplist_id_to_flush: Sequence[int] = tuple()
    # ip versions to destroy iplist_id ipsets
    iplist_id_versions_to_delete: Sequence[int] = tuple()
    # ips to add to iplist_id ipsets
    iplist_id_ips_to_add: Sequence[str] = tuple()
    # ips to delete from iplist_id ipsets
    iplist_id_ips_to_delete: Sequence[str] = tuple()
    # [purpose, version] lists to delete iplist_id ipsets from sync ipsets
    ipset_sync_iplist_ids_to_delete: Sequence[Tuple[str, int]] = tuple()
    # [purpose, version] lists to add iplist_id ipsets to sync ipsets
    ipset_sync_iplist_ids_to_add: Sequence[Tuple[str, int]] = tuple()
class RemoteIPListPlugin(MessageSource):
    """Message source that keeps local ip lists in sync with the server.

    Periodically downloads ip list changes over HTTP, applies them to the
    local sqlite tables (IPListPurpose / IPListRecord) and mirrors the
    result into the corresponding ipsets.
    """

    SCOPE = Scope.IM360_RESIDENT

    async def create_source(self, loop, sink):
        """Start the continuous-update background task on *loop*."""
        self.loop = loop
        self.sink = sink
        # note: uncaught errors are handled by the recurring check
        self.update_task = self.loop.create_task(self.do_continuous_update())

    async def shutdown(self):
        """Cancel the background sync task and wait for it to finish."""
        self.update_task.cancel()
        # note: suppress, to avoid sending it to Sentry
        with suppress(asyncio.CancelledError):
            await self.update_task

    @recurring_check(ERROR_TIMEOUT, ERROR_LIMIT)
    async def do_continuous_update(self):
        """Run sync cycles forever, honoring the server-requested delay."""
        while True:
            if g.get("shutdown_started"):
                logger.info("Shutdown started, stop iplist sync")
                return
            if not LicenseCLN.is_valid():
                # note: don't send this to Sentry:
                # - there are many events ~30k daily
                # - nothing to do with this info at the moment
                # (may be reconsidered)
                # https://im360.sentry.cloudlinux.com/organizations/sentry/issues/1343/?project=3 # noqa
                logger.warning("Skip iplist sync, since license is invalid")
                return
            try:
                with Benchmark() as bench:
                    (
                        total_blocked_count,
                        total_unblocked_count,
                        sync_point,
                        delay,
                    ) = await self.do_single_update()
                # Log (for benchmark) & sleep for the requested by server delay:
                # delay - elapsed_time
                sleep_for = max(0, delay - bench.elapsed_time_ns // 10**9)
                logger.info(
                    "Sync IPList response time characteristics:"
                    "Elapsed time: %.5gms. "
                    "Sleep for %s of %s",
                    bench.elapsed_time_ns * 1e-6,
                    sleep_for,
                    delay,
                )
                await asyncio.sleep(sleep_for)
            except Exception as err:
                if (
                    e := _find_cmd_error(err)
                ) and e.returncode == -signal.SIGTERM:
                    # Don't send SIGTERM to Sentry
                    # (presumably on shutdown on systemctl stop imunify360)
                    logger.warning(
                        "SIGTERM while updating remote lists, reason: %s"
                        ", cmd error=%s",
                        err,
                        e,
                    )
                    await asyncio.sleep(ERROR_TIMEOUT)
                else:
                    raise err

    async def get_sync_from_server(
        self, lazy_responses=True, raise_exception=False
    ):
        """Request the next chunk of ip list changes from the server.

        Returns ``(sync_point, delay, responses)``; responses has the shape:
        """
        # Iterable<
        # {id: str, block?: str[], unblock?: 'ALL'|str[], purposes: str[]} |
        # {id: str, unblock: 'ALL'} |
        # {delay: int, sync_point: dict}
        # >
        sync_point = self._get_sync_point()
        # default delay; replaced by the server-supplied value downstream
        delay = ERROR_TIMEOUT
        responses = await RemoteIPListAPI.sync(
            LicenseCLN.get_server_id(),
            PURPOSES,
            sync_point,
            self._get_checksum(),
            LicenseCLN.get_token(),
            lazy_responses=lazy_responses,
            raise_exception=raise_exception,
        )
        return sync_point, delay, responses

    async def do_single_update(self, sync=None, raise_exception=False):
        """Apply one server sync response set to the database and ipsets.

        Returns ``(total_blocked, total_unblocked, sync_point, delay)``.
        """
        if sync is None:
            sync_point, delay, responses = await self.get_sync_from_server(
                raise_exception=raise_exception
            )
        else:
            sync_point, delay, responses = sync
        # get lock while updating db, to avoid concurrent "restore"
        # ipset updates by the recurring check
        async with Protector.RULE_EDIT_LOCK:
            ipsets_data_to_update = []
            with instance.db.atomic():  # don't use async code (DEF-20047)
                total_unblocked_count, total_blocked_count = 0, 0
                for response in responses:
                    if "id" in response:
                        to_ipset = IPSetDataToUpdate()
                        to_ipset.iplist_id = response["id"]
                        unblocked_count, blocked_count = self._update_db(
                            response, to_ipset=to_ipset
                        )
                        ipsets_data_to_update.append(to_ipset)
                        total_unblocked_count += unblocked_count
                        total_blocked_count += blocked_count
                    elif "sync_point" in response:
                        # last response
                        logger.info("Processing response %s", response)
                        sync_point = response["sync_point"]  # update the value
                        delay = response["delay"]
            self._save_sync_point(sync_point)
            await self.update_ipsets(ipsets_data_to_update)
        logger.info(
            "Sync IPList response processed: "
            "unblocked=%d, blocked=%d, sync point=%s. ",
            total_unblocked_count,
            total_blocked_count,
            sync_point,
        )
        return total_blocked_count, total_unblocked_count, sync_point, delay

    async def update_ipsets(self, ipsets_data_to_update):
        """Apply each collected IPSetDataToUpdate to the kernel ipsets."""
        for to_ipset in ipsets_data_to_update:
            await self._update_ipsets(to_ipset)

    @staticmethod
    async def sleep_on_error(_, attempt):
        # linear backoff between ipset retries: 2s, 4s, 6s, ...
        await asyncio.sleep(attempt * 2)

    @retry_on(
        [IPSetError],
        max_tries=3,
        timeout=2,
        log=True,
        silent=True,
        on_error=sleep_on_error,
    )
    async def _update_ipsets(self, to_ipset: IPSetDataToUpdate):
        """Replay *to_ipset* into the ipsets, mirroring the DB changes."""
        assert to_ipset.iplist_id is not None, "iplist_id missed"
        # NOTE: the sequence of actions is the same
        # as the processing of the database (see _update_db method)
        # unblock all
        for version in sorted(to_ipset.iplist_id_to_flush):
            await IPSetSyncIPListRecords().flush_ips(
                to_ipset.iplist_id, f"ipv{version}"
            )
        # unblock
        for unblock_chunk in split_for_chunk(
            sorted(to_ipset.iplist_id_ips_to_delete), IPSET_BATCH_SIZE
        ):
            await IPSetSyncIPListRecords().delete_ips(
                to_ipset.iplist_id, unblock_chunk
            )
        # block
        for version in sorted(to_ipset.iplist_id_to_create):
            await IPSetSyncIPListRecords.create(to_ipset.iplist_id, version)
        # add new ips to id iplist
        for block_chunk in split_for_chunk(
            sorted(to_ipset.iplist_id_ips_to_add), IPSET_BATCH_SIZE
        ):
            await IPSetSyncIPListRecords().add_ips(
                to_ipset.iplist_id, block_chunk
            )
        # remove empty id iplists from sync lists
        for purpose, version in sorted(
            to_ipset.ipset_sync_iplist_ids_to_delete
        ):
            await IPSetSyncIPListPurpose().delete_id_iplist(
                purpose, to_ipset.iplist_id, f"ipv{version}"
            )
        # add new id iplists to sync lists
        for purpose, version in sorted(to_ipset.ipset_sync_iplist_ids_to_add):
            await IPSetSyncIPListPurpose().add_id_iplist(
                purpose, to_ipset.iplist_id, f"ipv{version}"
            )
        # destroy empty id iplists
        for version in sorted(to_ipset.iplist_id_versions_to_delete):
            await IPSetSyncIPListRecords().delete(
                to_ipset.iplist_id, f"ipv{version}"
            )

    def _update_db(self, response, *, to_ipset: IPSetDataToUpdate):
        """Apply a single ``{id, block, unblock, purposes}`` response to sqlite.

        Fills *to_ipset* with the corresponding ipset changes to apply later.
        Returns ``(unblocked_count, blocked_count)``.
        """
        unblocked_count, blocked_count = 0, 0
        iplist_id = response["id"]
        purposes = response.get("purposes")
        unblock = response.get("unblock", [])
        block = response.get("block", [])
        logger.info(
            "Processing response id=%s block=%s unblock=%s, purposes=%r",
            iplist_id,
            len(block),
            unblock if unblock == UNBLOCK_ALL else len(unblock),
            purposes,
        )
        # snapshot of (purpose, version) pairs before the update,
        # compared against the "after" snapshot below to decide which
        # sync-ipset memberships changed
        records = (
            IPListPurpose.select(
                IPListPurpose.purpose,
                IPListRecord.version,
            )
            .join(
                IPListRecord,
                on=(IPListPurpose.iplist_id == IPListRecord.iplist_id),
            )
            .where(IPListPurpose.iplist_id == iplist_id)
            .distinct()
            .tuples()
        )
        purpose_versions_set = set(records)
        ip_versions_set = set(version for _, version in purpose_versions_set)
        if unblock == UNBLOCK_ALL:
            unblocked_count += self.unblock_all(
                block, iplist_id, ip_versions_set, to_ipset=to_ipset
            )
        elif unblock:
            unblocked_count += len(unblock)
            self.unblock(
                unblock, iplist_id, ip_versions_set, to_ipset=to_ipset
            )
        if block:
            blocked_count += len(block)
            self.block(block, iplist_id, purposes, to_ipset=to_ipset)
        if not (unblock or block):
            logger.error("Useless sync chunk: %s", response)
        # check for empty id iplist for 4/6 ip version
        records_after = (
            IPListPurpose.select(
                IPListPurpose.purpose,
                IPListRecord.version,
            )
            .join(
                IPListRecord,
                on=(IPListPurpose.iplist_id == IPListRecord.iplist_id),
            )
            .where(IPListPurpose.iplist_id == iplist_id)
            .distinct()
            .tuples()
        )
        purpose_versions_set_after = set(records_after)
        # remove empty id iplists from sync lists
        to_ipset.ipset_sync_iplist_ids_to_delete = tuple(
            purpose_versions_set - purpose_versions_set_after
        )
        # add new id iplists to sync lists
        to_ipset.ipset_sync_iplist_ids_to_add = tuple(
            purpose_versions_set_after - purpose_versions_set
        )
        # destroy empty id iplists
        to_ipset.iplist_id_versions_to_delete = tuple(
            set(r[1] for r in purpose_versions_set)
            - set(r[1] for r in purpose_versions_set_after)
        )
        return unblocked_count, blocked_count

    @staticmethod
    def unblock_all(block, iplist_id, ip_versions_set, *, to_ipset):
        """Delete every record of *iplist_id*; schedule ipset flushes.

        Returns the number of deleted rows.
        """
        # remove everything about iplist_id
        unblocked_count = (
            IPListRecord.delete()
            .where(IPListRecord.iplist_id == iplist_id)
            .execute()
        )
        to_ipset.iplist_id_to_flush = tuple(ip_versions_set)
        if not block:
            # nothing will be re-added: drop the purpose rows too
            IPListPurpose.delete().where(
                IPListPurpose.iplist_id == iplist_id
            ).execute()
        return unblocked_count

    @staticmethod
    def unblock(unblock, iplist_id, ip_versions_set, *, to_ipset):
        """Delete the given ips of *iplist_id*; schedule ipset deletions."""
        # DELETE FROM IPListRecord WHERE
        # iplist_id=iplist_id and ip in unblock
        for small_chunk in split_for_chunk(unblock, 10000):
            q = IPListRecord.delete().where(
                (IPListRecord.iplist_id == iplist_id)
                & SQL(
                    "(network_address, netmask, version)",
                ).in_(
                    SQL(
                        "(VALUES {})".format(
                            ",".join(
                                str(
                                    pack_ip_network(
                                        ipaddress.ip_network(ip),
                                    )
                                )
                                for ip in small_chunk
                            )
                        )
                    )
                )
            )
            q.execute()
        # only schedule ipset deletions for ip versions that actually
        # existed for this iplist_id (pack_ip_network()[-1] is the version)
        to_ipset.iplist_id_ips_to_delete = tuple(
            ip
            for ip in unblock
            if pack_ip_network(ipaddress.ip_network(ip))[-1] in ip_versions_set
        )

    @staticmethod
    def block(block, iplist_id, purposes, *, to_ipset):
        """Replace purposes and insert the given ips for *iplist_id*."""
        # INSERT INTO IPList (iplist_id, purposes)
        # ON CONFLICT DO NOTHING
        IPListPurpose.delete().where(
            IPListPurpose.iplist_id == iplist_id
        ).execute()
        for purpose in purposes:
            IPListPurpose.insert(
                iplist_id=iplist_id, purpose=purpose
            ).on_conflict_ignore().execute()
        # INSERT INTO IPListRecord (iplist_id, ip)
        # VALUES (iplist_id, ip) for ip in block
        # ON CONFLICT DO NOTHING
        ips_versions = {
            NumericIPVersion.ipv4.value: False,
            NumericIPVersion.ipv6.value: False,
        }
        for small_chunk in split_for_chunk(block, DB_BATCH_SIZE):
            data = []
            for ip in small_chunk:
                net, mask, version = pack_ip_network(ipaddress.ip_network(ip))
                ips_versions[version] = True
                data.append(
                    {
                        "iplist_id": iplist_id,
                        "network_address": net,
                        "netmask": mask,
                        "version": version,
                    }
                )
            IPListRecord.insert_many(data).on_conflict_ignore().execute()
        to_ipset.iplist_id_to_create = tuple(
            version for version, is_used in ips_versions.items() if is_used
        )  # create id iplist, ignore error if it exists
        # add new ips to id iplist
        to_ipset.iplist_id_ips_to_add = tuple(block)

    def _save_sync_point(self, sync_point):
        """Persist *sync_point* so the next request resumes from it."""
        with open(SYNC_POINT_FILE, "w") as f:
            json.dump(sync_point, f)

    def _get_sync_point(self):
        """Load the persisted sync point; return {} if missing or corrupt."""
        try:
            with open(SYNC_POINT_FILE) as f:
                return json.load(f)
        except Exception as e:
            # if such approach will be considered for developing,
            # we will have a lot of sync_point objects and probably will
            # store them in sqlite
            logger.info("Sync point is malformed, err=%s. Resetting.", e)
            return {}

    def _get_checksum(self):
        """Return number of items (ips) in each iplist."""
        iplist_ids = [
            rec.iplist_id
            for rec in IPListPurpose.select(IPListPurpose.iplist_id)
            .distinct()
            .where(IPListPurpose.purpose.in_(PURPOSES))
            .execute()
        ]
        query = (
            IPListRecord.select(
                IPListRecord.iplist_id, fn.COUNT().alias("length")
            )
            .where(IPListRecord.iplist_id << iplist_ids)
            .group_by(IPListRecord.iplist_id)
            .execute()
        )
        checksum = {}
        for record in query:
            checksum[record.iplist_id] = record.length
        logger.info("Checksum: %s", checksum)
        return checksum
def _find_cmd_error(
    e: BaseException,
) -> Optional[subprocess.CalledProcessError]:
    """Return the nearest CalledProcessError in the exception chain, if any.

    Walks the ``__cause__``/``__context__`` links iteratively. A ``seen``
    set of object ids guards against explicitly constructed reference
    cycles in the chain, which would make the previous recursive version
    raise RecursionError instead of returning.
    """
    seen = set()
    current: Optional[BaseException] = e
    while current is not None and id(current) not in seen:
        if isinstance(current, subprocess.CalledProcessError):
            return current
        seen.add(id(current))
        # prefer the explicit cause (raise ... from), fall back to context
        current = current.__cause__ or current.__context__
    return None  # not found