You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

247 lines
7.0 KiB

2 years ago
  1. /*
  2. * Copyright 2008, The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <errno.h>
  17. #include <stdlib.h>
  18. #include <string.h>
  19. #include <sys/socket.h>
  20. #include <sys/uio.h>
  21. #include <linux/if_ether.h>
  22. #include <linux/if_packet.h>
  23. #include <netinet/in.h>
  24. #include <netinet/ip.h>
  25. #include <netinet/udp.h>
  26. #include <unistd.h>
  27. #include <stdio.h>
  28. #include "dhcpmsg.h"
  /*
   * Error reporter defined elsewhere in the project. This file calls it with a
   * single message string and returns its result to the caller.
   * Declared with a full prototype so the argument is type-checked; an empty
   * parameter list would mean "takes no arguments" under C23.
   * NOTE(review): signature assumed to match the out-of-file definition
   * (int fatal(const char *)) — confirm.
   */
  int fatal(const char *reason);
  /*
   * Open a PF_PACKET / SOCK_DGRAM socket for IPv4 frames (ETH_P_IP) and bind it
   * to interface if_index, copying ETH_ALEN bytes of hwaddr into the bind
   * address.
   *
   * ifname: accepted for interface compatibility; not used.
   * hwaddr: at least ETH_ALEN bytes.
   *
   * Returns the socket descriptor on success. On failure returns whatever
   * fatal() returns. fatal() is defined elsewhere and presumably reports errno
   * and returns a negative value (TODO confirm). On bind failure the socket is
   * closed first, so the descriptor does not leak.
   */
  int open_raw_socket(const char *ifname __attribute__((unused)), uint8_t *hwaddr, int if_index)
  {
      int s;
      struct sockaddr_ll bindaddr;

      if ((s = socket(PF_PACKET, SOCK_DGRAM, htons(ETH_P_IP))) < 0) {
          return fatal("socket(PF_PACKET)");
      }

      memset(&bindaddr, 0, sizeof(bindaddr));
      bindaddr.sll_family = AF_PACKET;
      bindaddr.sll_protocol = htons(ETH_P_IP);
      bindaddr.sll_halen = ETH_ALEN;
      memcpy(bindaddr.sll_addr, hwaddr, ETH_ALEN);
      bindaddr.sll_ifindex = if_index;

      if (bind(s, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) < 0) {
          /* close() may overwrite errno; keep the bind() error in case
           * fatal() reports it. */
          int saved_errno = errno;
          close(s);
          errno = saved_errno;
          return fatal("Cannot bind raw socket to interface");
      }
      return s;
  }
  /*
   * Add count bytes at buffer into a 16-bit one's-complement sum (RFC 1071),
   * starting from startsum, and return the folded (<= 0xffff) result.
   * The result is NOT complemented; pass it to finish_sum() for the final
   * checksum value.
   *
   * Words are loaded in host byte order via memcpy, which avoids
   * misaligned or type-punned loads. RFC 1071 shows the one's-complement sum
   * is byte-order independent, so the result can be stored into a header
   * field without swapping.
   *
   * An odd trailing byte is treated as the first byte of a 16-bit word whose
   * second byte is zero. That padding is done in memory, so it is correct on
   * both little- and big-endian hosts.
   *
   * The 32-bit accumulator cannot overflow while count stays below roughly
   * 128 KiB and startsum is a folded (<= 0xffff) value.
   */
  static uint32_t checksum(void *buffer, unsigned int count, uint32_t startsum)
  {
      const uint8_t *p = (const uint8_t *)buffer;
      uint32_t sum = startsum;
      uint32_t upper16;

      while (count > 1) {
          uint16_t word;
          memcpy(&word, p, sizeof(word));
          sum += word;
          p += sizeof(word);
          count -= sizeof(word);
      }
      if (count > 0) {
          uint16_t word = 0;      /* second byte stays zero (padding) */
          memcpy(&word, p, 1);
          sum += word;
      }
      /* Fold carries back into the low 16 bits. */
      while ((upper16 = (sum >> 16)) != 0) {
          sum = (sum & 0xffff) + upper16;
      }
      return sum;
  }
  /*
   * Turn a folded one's-complement sum into the final checksum: flip the low
   * 16 bits and discard everything above them.
   */
  static uint32_t finish_sum(uint32_t sum)
  {
      const uint32_t low_mask = 0xffffu;
      return (sum ^ low_mask) & low_mask;
  }
  /*
   * Wrap a DHCP payload in IPv4 and UDP headers and send it as a link-layer
   * broadcast (ff:ff:ff:ff:ff:ff) on interface if_index through socket s.
   *
   * msg/size:    payload; size bytes starting at msg are sent.
   * saddr/daddr: copied into the IP header without byte swapping, so they
   *              are presumably already in network byte order (TODO confirm
   *              with callers).
   * sport/dport: port numbers in host order; htons() is applied here.
   *
   * Returns the value of sendmsg(). Returns -1 with errno = EINVAL when size
   * is negative or too large for the 16-bit IP/UDP length fields.
   */
  int send_packet(int s, int if_index, struct dhcp_msg *msg, int size,
                  uint32_t saddr, uint32_t daddr, uint32_t sport, uint32_t dport)
  {
      struct iphdr ip;
      struct udphdr udp;
      struct iovec iov[3];
      struct msghdr msghdr;
      struct sockaddr_ll destaddr;
      uint32_t udpsum;
      uint16_t temp;

      /* ip.tot_len and udp.len are 16 bits wide, and a negative size would
       * become a huge iov_len below. */
      if (size < 0 || (size_t)size > 0xffff - sizeof(ip) - sizeof(udp)) {
          errno = EINVAL;
          return -1;
      }

      /* IPv4 header without options: ihl is the header size in 32-bit words. */
      memset(&ip, 0, sizeof(ip));
      ip.version = IPVERSION;
      ip.ihl = sizeof(ip) >> 2;
      ip.tos = 0;
      ip.tot_len = htons(sizeof(ip) + sizeof(udp) + size);
      ip.id = 0;
      ip.frag_off = 0;
      ip.ttl = IPDEFTTL;
      ip.protocol = IPPROTO_UDP;
      ip.check = 0;   /* must be zero while the header checksum is computed */
      ip.saddr = saddr;
      ip.daddr = daddr;
      ip.check = finish_sum(checksum(&ip, sizeof(ip), 0));

      memset(&udp, 0, sizeof(udp));
      udp.source = htons(sport);
      udp.dest = htons(dport);
      udp.len = htons(sizeof(udp) + size);
      udp.check = 0;

      /* UDP checksum = pseudo header (saddr, daddr, zero/protocol, UDP length)
       * + UDP header + payload. */
      udpsum = checksum(&ip.saddr, sizeof(ip.saddr), 0);
      udpsum = checksum(&ip.daddr, sizeof(ip.daddr), udpsum);
      temp = htons(IPPROTO_UDP);
      udpsum = checksum(&temp, sizeof(temp), udpsum);
      temp = udp.len;
      udpsum = checksum(&temp, sizeof(temp), udpsum);
      udpsum = checksum(&udp, sizeof(udp), udpsum);
      udpsum = checksum(msg, size, udpsum);
      udp.check = finish_sum(udpsum);
      /* RFC 768: a transmitted 0 means "no checksum", so a computed 0 is sent
       * as all ones instead. */
      if (udp.check == 0) {
          udp.check = 0xffff;
      }

      iov[0].iov_base = (char *)&ip;
      iov[0].iov_len = sizeof(ip);
      iov[1].iov_base = (char *)&udp;
      iov[1].iov_len = sizeof(udp);
      iov[2].iov_base = (char *)msg;
      iov[2].iov_len = size;

      memset(&destaddr, 0, sizeof(destaddr));
      destaddr.sll_family = AF_PACKET;
      destaddr.sll_protocol = htons(ETH_P_IP);
      destaddr.sll_ifindex = if_index;
      destaddr.sll_halen = ETH_ALEN;
      memcpy(destaddr.sll_addr, "\xff\xff\xff\xff\xff\xff", ETH_ALEN);

      /* Zero first so any libc-specific padding members are initialised. */
      memset(&msghdr, 0, sizeof(msghdr));
      msghdr.msg_name = &destaddr;
      msghdr.msg_namelen = sizeof(destaddr);
      msghdr.msg_iov = iov;
      msghdr.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
      msghdr.msg_control = NULL;
      msghdr.msg_controllen = 0;
      msghdr.msg_flags = 0;

      return sendmsg(s, &msghdr, 0);
  }
  128. int receive_packet(int s, struct dhcp_msg *msg)
  129. {
  130. int nread;
  131. int is_valid;
  132. struct dhcp_packet {
  133. struct iphdr ip;
  134. struct udphdr udp;
  135. struct dhcp_msg dhcp;
  136. } packet;
  137. int dhcp_size;
  138. uint32_t sum;
  139. uint16_t temp;
  140. uint32_t saddr, daddr;
  141. nread = read(s, &packet, sizeof(packet));
  142. if (nread < 0) {
  143. return -1;
  144. }
  145. /*
  146. * The raw packet interface gives us all packets received by the
  147. * network interface. We need to filter out all packets that are
  148. * not meant for us.
  149. */
  150. is_valid = 0;
  151. if (nread < (int)(sizeof(struct iphdr) + sizeof(struct udphdr))) {
  152. #if VERBOSE
  153. ALOGD("Packet is too small (%d) to be a UDP datagram", nread);
  154. #endif
  155. } else if (packet.ip.version != IPVERSION || packet.ip.ihl != (sizeof(packet.ip) >> 2)) {
  156. #if VERBOSE
  157. ALOGD("Not a valid IP packet");
  158. #endif
  159. } else if (nread < ntohs(packet.ip.tot_len)) {
  160. #if VERBOSE
  161. ALOGD("Packet was truncated (read %d, needed %d)", nread, ntohs(packet.ip.tot_len));
  162. #endif
  163. } else if (packet.ip.protocol != IPPROTO_UDP) {
  164. #if VERBOSE
  165. ALOGD("IP protocol (%d) is not UDP", packet.ip.protocol);
  166. #endif
  167. } else if (packet.udp.dest != htons(PORT_BOOTP_CLIENT)) {
  168. #if VERBOSE
  169. ALOGD("UDP dest port (%d) is not DHCP client", ntohs(packet.udp.dest));
  170. #endif
  171. } else {
  172. is_valid = 1;
  173. }
  174. if (!is_valid) {
  175. return -1;
  176. }
  177. /* Seems like it's probably a valid DHCP packet */
  178. /* validate IP header checksum */
  179. sum = finish_sum(checksum(&packet.ip, sizeof(packet.ip), 0));
  180. if (sum != 0) {
  181. printf("IP header checksum failure (0x%x)\n", packet.ip.check);
  182. return -1;
  183. }
  184. /*
  185. * Validate the UDP checksum.
  186. * Since we don't need the IP header anymore, we "borrow" it
  187. * to construct the pseudo header used in the checksum calculation.
  188. */
  189. dhcp_size = ntohs(packet.udp.len) - sizeof(packet.udp);
  190. /*
  191. * check validity of dhcp_size.
  192. * 1) cannot be negative or zero.
  193. * 2) src buffer contains enough bytes to copy
  194. * 3) cannot exceed destination buffer
  195. */
  196. if ((dhcp_size <= 0) ||
  197. ((int)(nread - sizeof(struct iphdr) - sizeof(struct udphdr)) < dhcp_size) ||
  198. ((int)sizeof(struct dhcp_msg) < dhcp_size)) {
  199. #if VERBOSE
  200. printf("Malformed Packet\n");
  201. #endif
  202. return -1;
  203. }
  204. saddr = packet.ip.saddr;
  205. daddr = packet.ip.daddr;
  206. nread = ntohs(packet.ip.tot_len);
  207. memset(&packet.ip, 0, sizeof(packet.ip));
  208. packet.ip.saddr = saddr;
  209. packet.ip.daddr = daddr;
  210. packet.ip.protocol = IPPROTO_UDP;
  211. packet.ip.tot_len = packet.udp.len;
  212. temp = packet.udp.check;
  213. packet.udp.check = 0;
  214. sum = finish_sum(checksum(&packet, nread, 0));
  215. packet.udp.check = temp;
  216. if (!sum)
  217. sum = finish_sum(sum);
  218. if (temp != sum) {
  219. printf("UDP header checksum failure (0x%x should be 0x%x)\n", sum, temp);
  220. return -1;
  221. }
  222. memcpy(msg, &packet.dhcp, dhcp_size);
  223. return dhcp_size;
  224. }