Back to Site
Loading...
Searching...
No Matches
mdns.h
Go to the documentation of this file.
1#pragma once
2
3#include "socket.h"
4#include "logger.h"
5#include <string>
6#include <vector>
7#include <memory>
8#include <functional>
9#include <thread>
10#include <atomic>
11#include <mutex>
12#include <chrono>
13#include <map>
14#include <cstdint>
15#include <condition_variable>
16#include <random>
17
18namespace librats {
19
// mDNS-tagged logging shorthands. Each one forwards its single argument,
// untouched, to the matching generic logger macro (presumably provided by
// logger.h) and supplies the module tag "mdns" as the first argument.
// NOTE: the preprocessor ignores the enclosing `librats` namespace, so these
// macro names are visible to every translation unit that includes this header.
#define LOG_MDNS_DEBUG(msg) LOG_DEBUG("mdns", msg)
#define LOG_MDNS_INFO(msg) LOG_INFO("mdns", msg)
#define LOG_MDNS_WARN(msg) LOG_WARN("mdns", msg)
#define LOG_MDNS_ERROR(msg) LOG_ERROR("mdns", msg)
24
// ---- mDNS protocol constants ----
// NOTE(review): a namespace-scope `const std::string` in a header has internal
// linkage, so every translation unit that includes this file gets its own,
// dynamically-initialized copy of each string below.

// UDP port used for all mDNS traffic.
constexpr uint16_t MDNS_PORT = 5353;

// Link-local multicast group addresses for mDNS (IPv4 and IPv6).
const std::string MDNS_MULTICAST_IPv4{"224.0.0.251"};
const std::string MDNS_MULTICAST_IPv6{"ff02::fb"};

// DNS-SD service type string for librats, plus the ".local." domain suffix.
const std::string LIBRATS_SERVICE_TYPE{"_librats._tcp.local."};
const std::string LIBRATS_SERVICE_INSTANCE_SUFFIX{".local."};
31
// DNS resource-record TYPE codes used by this module (16-bit wire values).
enum class DnsRecordType : uint16_t {
    A    = 0x0001,  // IPv4 host address (RFC 1035)
    PTR  = 0x000C,  // domain-name pointer, used for DNS-SD browsing (RFC 1035 / RFC 6763)
    TXT  = 0x0010,  // text strings, carries DNS-SD key=value metadata (RFC 1035 / RFC 6763)
    AAAA = 0x001C,  // IPv6 host address (RFC 3596)
    SRV  = 0x0021   // service location: priority, weight, port, target (RFC 2782)
};
40
// DNS CLASS field values (16-bit wire values).
// mDNS overloads the top bit of the class field: in resource records it is the
// cache-flush bit (RFC 6762 section 10.2); in questions it is the
// unicast-response bit (RFC 6762 section 5.4).
enum class DnsRecordClass : uint16_t {
    CLASS_IN       = 0x0001,           // Internet class
    CLASS_IN_FLUSH = 0x8000 | 0x0001   // Internet class with the top (cache-flush) bit set
};
46
// 16-bit DNS header flag words used when building mDNS messages.
enum class MdnsFlags : uint16_t {
    QUERY         = 0x0000,           // QR = 0: standard query
    RESPONSE      = 0x8000,           // QR = 1: response
    AUTHORITATIVE = 0x8000 | 0x0400   // QR = 1 plus AA = 1 (authoritative answer)
};
53
54// Discovered service structure
56 std::string service_name; // e.g., "rats-node-abc123._librats._tcp.local."
57 std::string host_name; // e.g., "MyComputer.local."
58 std::string ip_address; // IPv4 or IPv6 address
59 uint16_t port; // Service port
60 std::map<std::string, std::string> txt_records; // TXT record key-value pairs
61 std::chrono::steady_clock::time_point last_seen;
62
63 MdnsService() : port(0) {}
64
65 MdnsService(const std::string& name, const std::string& host,
66 const std::string& ip, uint16_t p)
67 : service_name(name), host_name(host), ip_address(ip), port(p),
68 last_seen(std::chrono::steady_clock::now()) {}
69};
70
71// DNS message header
83
84// DNS question structure
94
95// DNS resource record structure
97 std::string name;
100 uint32_t ttl;
101 std::vector<uint8_t> data;
102 size_t data_offset_in_packet; // Offset of data in original packet for DNS compression
103
105 DnsResourceRecord(const std::string& n, DnsRecordType t, DnsRecordClass c, uint32_t ttl_val)
106 : name(n), type(t), record_class(c), ttl(ttl_val), data_offset_in_packet(0) {}
107};
108
109// Complete DNS message structure
112 std::vector<DnsQuestion> questions;
113 std::vector<DnsResourceRecord> answers;
114 std::vector<DnsResourceRecord> authorities;
115 std::vector<DnsResourceRecord> additionals;
116 std::vector<uint8_t> raw_packet; // Original packet for DNS compression resolution
117
118 DnsMessage() = default;
119};
120
121// mDNS service discovery callback
122using MdnsServiceCallback = std::function<void(const MdnsService& service, bool is_new)>;
123
125public:
126 explicit MdnsClient(const std::string& service_instance_name = "", uint16_t service_port = 0);
128
129 // Core functionality
130 bool start();
131 void stop();
133 bool is_running() const;
134
135 // Service announcement
136 bool announce_service(const std::string& instance_name, uint16_t port,
137 const std::map<std::string, std::string>& txt_records = {});
139 bool is_announcing() const;
140
141 // Service discovery
145 bool is_discovering() const;
146
147 // Query for specific services
149
150 // Get discovered services
151 std::vector<MdnsService> get_discovered_services() const;
152 std::vector<MdnsService> get_recent_services(std::chrono::seconds max_age = std::chrono::seconds(300)) const;
153 void clear_old_services(std::chrono::seconds max_age = std::chrono::seconds(600));
154
155 // Configuration
156 void set_announcement_interval(std::chrono::seconds interval);
157 void set_query_interval(std::chrono::seconds interval);
158 void set_service_type(const std::string& service_type);
159
160private:
161 // Core properties
162 std::string service_instance_name_;
163 uint16_t service_port_;
164 std::string service_type_; // Dynamic service type (e.g., "_rats-search._tcp.local.")
165 std::map<std::string, std::string> txt_records_;
166
167 // Network properties
168 socket_t multicast_socket_;
169 std::string local_hostname_;
170 std::string local_ip_address_;
171
172 // Threading and state
173 std::atomic<bool> running_;
174 std::atomic<bool> announcing_;
175 std::atomic<bool> discovering_;
176 std::thread receiver_thread_;
177 std::thread announcer_thread_;
178 std::thread querier_thread_;
179
180 // Condition variable and its mutex, used to signal immediate shutdown
181 std::condition_variable shutdown_cv_;
182 std::mutex shutdown_mutex_;
183
184 // Discovery state
185 mutable std::mutex services_mutex_;
186 std::map<std::string, MdnsService> discovered_services_;
187 MdnsServiceCallback service_callback_;
188
189 // Timing configuration
190 std::chrono::seconds announcement_interval_;
191 std::chrono::seconds query_interval_;
192
193 // Random number generator for response delays
194 mutable std::mt19937 rng_;
195
196 // Socket operations
197 bool create_multicast_socket();
198 bool join_multicast_group();
199 bool leave_multicast_group();
200 void close_multicast_socket();
201
202 // Message handling threads
203 void receiver_loop();
204 void announcer_loop();
205 void querier_loop();
206
207 // Packet processing
208 void handle_received_packet(const std::vector<uint8_t>& packet, const std::string& sender_ip);
209 void process_mdns_message(const DnsMessage& message, const std::string& sender_ip);
210 void process_query(const DnsMessage& query, const std::string& sender_ip);
211 void process_response(const DnsMessage& response, const std::string& sender_ip);
212
213 // Service processing
214 void extract_service_from_response(const DnsMessage& response, const std::string& sender_ip);
215 bool is_librats_service(const std::string& service_name) const;
216 void add_or_update_service(const MdnsService& service);
217
218 // Message creation
219 DnsMessage create_query_message();
220 DnsMessage create_announcement_message();
221 DnsMessage create_response_message(const DnsQuestion& question);
222
223 // DNS record creation
224 DnsResourceRecord create_ptr_record(const std::string& service_type, const std::string& instance_name, uint32_t ttl = 120);
225 DnsResourceRecord create_srv_record(const std::string& instance_name, const std::string& hostname, uint16_t port, uint32_t ttl = 120);
226 DnsResourceRecord create_txt_record(const std::string& instance_name, const std::map<std::string, std::string>& txt_data, uint32_t ttl = 120);
227 DnsResourceRecord create_a_record(const std::string& hostname, const std::string& ip_address, uint32_t ttl = 120);
228
229 // DNS serialization/deserialization
230 std::vector<uint8_t> serialize_dns_message(const DnsMessage& message);
231 bool deserialize_dns_message(const std::vector<uint8_t>& data, DnsMessage& message);
232 bool read_resource_records(const std::vector<uint8_t>& data, size_t& offset,
233 uint16_t count, std::vector<DnsResourceRecord>& records);
234
235 // DNS name compression helpers
236 void write_dns_name(std::vector<uint8_t>& buffer, const std::string& name);
237 std::string read_dns_name(const std::vector<uint8_t>& buffer, size_t& offset);
238 void write_uint16(std::vector<uint8_t>& buffer, uint16_t value);
239 void write_uint32(std::vector<uint8_t>& buffer, uint32_t value);
240 uint16_t read_uint16(const std::vector<uint8_t>& buffer, size_t& offset);
241 uint32_t read_uint32(const std::vector<uint8_t>& buffer, size_t& offset);
242
243 // TXT record helpers
244 std::vector<uint8_t> encode_txt_record(const std::map<std::string, std::string>& txt_data);
245 std::map<std::string, std::string> decode_txt_record(const std::vector<uint8_t>& txt_data);
246
247 // SRV record helpers
248 std::vector<uint8_t> encode_srv_record(uint16_t priority, uint16_t weight, uint16_t port, const std::string& target);
249 bool decode_srv_record(const std::vector<uint8_t>& full_packet, size_t data_offset,
250 uint16_t& priority, uint16_t& weight, uint16_t& port, std::string& target);
251
252 // Utility functions
253 std::string get_local_hostname();
254 std::string get_local_ip_address();
255 std::string create_service_instance_name(const std::string& instance_name);
256 std::string extract_instance_name_from_service(const std::string& service_name);
257 bool send_multicast_packet(const std::vector<uint8_t>& packet);
258
259 // Name validation
260 bool is_valid_dns_name(const std::string& name) const;
261 std::string normalize_dns_name(const std::string& name) const;
262};
263
264} // namespace librats
void clear_old_services(std::chrono::seconds max_age=std::chrono::seconds(600))
std::vector< MdnsService > get_discovered_services() const
bool is_running() const
void set_query_interval(std::chrono::seconds interval)
bool announce_service(const std::string &instance_name, uint16_t port, const std::map< std::string, std::string > &txt_records={})
void set_service_callback(MdnsServiceCallback callback)
void set_service_type(const std::string &service_type)
bool is_discovering() const
std::vector< MdnsService > get_recent_services(std::chrono::seconds max_age=std::chrono::seconds(300)) const
void set_announcement_interval(std::chrono::seconds interval)
bool is_announcing() const
MdnsClient(const std::string &service_instance_name="", uint16_t service_port=0)
std::function< void(const MdnsService &service, bool is_new)> MdnsServiceCallback
Definition mdns.h:122
const std::string MDNS_MULTICAST_IPv6
Definition mdns.h:28
const std::string LIBRATS_SERVICE_INSTANCE_SUFFIX
Definition mdns.h:30
MdnsFlags
Definition mdns.h:48
const uint16_t MDNS_PORT
Definition mdns.h:26
DnsRecordClass
Definition mdns.h:42
DnsRecordType
Definition mdns.h:33
const std::string MDNS_MULTICAST_IPv4
Definition mdns.h:27
const std::string LIBRATS_SERVICE_TYPE
Definition mdns.h:29
STL namespace.
int socket_t
Definition socket.h:22
uint16_t additional_count
Definition mdns.h:78
uint16_t answer_count
Definition mdns.h:76
uint16_t flags
Definition mdns.h:74
uint16_t question_count
Definition mdns.h:75
uint16_t transaction_id
Definition mdns.h:73
uint16_t authority_count
Definition mdns.h:77
std::vector< uint8_t > raw_packet
Definition mdns.h:116
std::vector< DnsResourceRecord > authorities
Definition mdns.h:114
DnsHeader header
Definition mdns.h:111
std::vector< DnsResourceRecord > answers
Definition mdns.h:113
std::vector< DnsQuestion > questions
Definition mdns.h:112
std::vector< DnsResourceRecord > additionals
Definition mdns.h:115
DnsRecordType type
Definition mdns.h:87
DnsRecordClass record_class
Definition mdns.h:88
std::string name
Definition mdns.h:86
DnsQuestion(const std::string &n, DnsRecordType t, DnsRecordClass c)
Definition mdns.h:91
DnsRecordType type
Definition mdns.h:98
DnsRecordClass record_class
Definition mdns.h:99
DnsResourceRecord(const std::string &n, DnsRecordType t, DnsRecordClass c, uint32_t ttl_val)
Definition mdns.h:105
std::vector< uint8_t > data
Definition mdns.h:101
std::string service_name
Definition mdns.h:56
std::chrono::steady_clock::time_point last_seen
Definition mdns.h:61
std::string ip_address
Definition mdns.h:58
MdnsService(const std::string &name, const std::string &host, const std::string &ip, uint16_t p)
Definition mdns.h:65
std::map< std::string, std::string > txt_records
Definition mdns.h:60
std::string host_name
Definition mdns.h:57
uint16_t port
Definition mdns.h:59