BeRTOS
tftp.c
Go to the documentation of this file.
00001 
00040 #include "tftp.h"
00041 #include "cfg/cfg_tftp.h"
00042 #define LOG_LEVEL   TFTP_LOG_LEVEL
00043 #define LOG_FORMAT  TFTP_LOG_FORMAT
00044 #include <cfg/log.h>
00045 
00046 //#include <lwip/in.h>
00047 #include <lwip/inet.h>
00048 #include <lwip/sockets.h>
00049 #include <string.h> //memset
00050 
/*
 * Size of a full TFTP DATA packet: 4-byte header (opcode + block number)
 * plus 512 bytes of payload (RFC 1350). A DATA packet shorter than this
 * terminates the transfer.
 */
#define TFTP_PACKET_SIZE 516

/*
 * Declare a struct timeval called `name` initialized from `timeout`,
 * expressed in milliseconds.
 * The argument is parenthesized so that expressions like `a + b` are
 * converted as a whole. Note that `timeout` is evaluated twice, so pass a
 * side-effect-free expression.
 */
#define DECLARE_TIMEOUT(name, timeout) \
    struct timeval name = { \
        .tv_sec = (timeout) / 1000, \
        .tv_usec = ((timeout) % 1000) * 1000, \
    }

/* KFile type tag stored by tftp_init() and asserted by TFTP_CAST(). */
#define KFT_TFTPSESSION MAKE_ID('T', 'F', 'T', 'P')
00059 INLINE TftpSession *TFTP_CAST(KFile *fd)
00060 {
00061     ASSERT(fd->_type == KFT_TFTPSESSION);
00062     return (TftpSession *)containerof(fd, TftpSession, kfile_request);
00063 }
00064 
00065 /*
00066  * Check if received data is correct and send ACK if ok.
00067  */
00068 static int checkPacket(TftpSession *ctx, const Tftpframe *frame)
00069 {
00070     LOG_INFO("Checking block %hd\n", ctx->block);
00071     if (ntohs(frame->hdr.opcode) != TFTP_DATA)
00072     {
00073         LOG_INFO("Opcode != TFTP_DATA (%hd != %d)\n", ntohs(frame->hdr.opcode), TFTP_DATA);
00074         return -1;
00075     }
00076     if (ntohs(frame->hdr.th_u.block) != ctx->block + 1)
00077         return -1;
00078 
00079     ctx->block++;
00080     // if everything was ok, send ACK
00081     // ACK is already in network order
00082     struct ackframe ack;
00083     ack.opcode = TFTP_ACK;
00084     ack.block_num = htons(ctx->block);
00085     ssize_t rc = lwip_sendto(ctx->sock, &ack, 4, 0, (struct sockaddr *)&ctx->addr, ctx->addr_len);
00086     if (rc == 4)
00087         return rc;
00088     else
00089         return -1;
00090 }
00091 
00092 /*
00093  * Return >0 if there's something to read in ctx, 0 on timeout, -1 on errors
00094  */
00095 static int tftp_waitEvent(TftpSession *ctx, struct timeval *timeout)
00096 {
00097     fd_set inset;
00098     FD_ZERO(&inset);
00099     FD_SET(ctx->sock, &inset);
00100     struct timeval tmp = *timeout;
00101     return lwip_select(ctx->sock + 1, &inset, NULL, NULL, &tmp);
00102 }
00103 
00104 /*
00105  * Read a block from TFTP.
00106  * \param size Must be exactly 516 bytes
00107  * \param timeout Time to wait the network connection, may be NULL to wait forever
00108  * \return Number of bytes read if success, TFTP_ERR_TIMEOUT on timeout, TFTP_ERR otherwise
00109  */
00110 static ssize_t tftp_readPacket(TftpSession *ctx, Tftpframe *frame, mtime_t timeout)
00111 {
00112     DECLARE_TIMEOUT(wait_tm, timeout);
00113 
00114     int res = tftp_waitEvent(ctx, &wait_tm);
00115     if (res == 0)
00116         return TFTP_ERR_TIMEOUT;
00117     if (res == -1)
00118         return TFTP_ERR;
00119 
00120     ssize_t rlen = lwip_recvfrom(ctx->sock, frame, sizeof(Tftpframe), 0, NULL, NULL);
00121     LOG_INFO("Received %zd bytes\n", rlen);
00122     if (rlen > 0 && (checkPacket(ctx, frame) > 0))
00123         return rlen;
00124     else
00125         return TFTP_ERR;
00126 }
00127 
00128 static size_t tftp_read(struct KFile *fd, void *buf, size_t size)
00129 {
00130     TftpSession *fds = TFTP_CAST(fd);
00131     uint8_t *_buf = (uint8_t *) buf;
00132     size_t read_bytes = 0;
00133     size_t offset = fds->valid_data - fds->bytes_available;
00134 
00135     if (fds->pending_ack)
00136     {
00137         ASSERT(fds->block == 0);
00138         struct ackframe ack;
00139         ack.opcode = TFTP_ACK;
00140         ack.block_num = fds->block;
00141         lwip_sendto(fds->sock, &ack, 4, 0, (struct sockaddr *)&fds->addr, fds->addr_len);
00142         fds->pending_ack = false;
00143     }
00144 
00145     if (fds->bytes_available < size)
00146     {
00147         /* check if we were called again after an error */
00148         if (fds->bytes_available > 0)
00149         {
00150             memcpy(_buf, fds->frame.data + offset, fds->bytes_available);
00151             LOG_INFO("ba < size. Copied %zd bytes from offset %zd\n", fds->bytes_available, offset);
00152             /* adjust buf and size */
00153             _buf += fds->bytes_available;
00154             size -= fds->bytes_available;
00155             read_bytes += fds->bytes_available;
00156         }
00157 
00158         if (!fds->is_xfer_end)
00159         {
00160             LOG_INFO("Waiting for new TFTP packet\n");
00161             /* get more data, we can wait since the function is blocking */
00162             ssize_t rd = tftp_readPacket(fds, &fds->frame, fds->timeout);
00163             if (rd < 0)
00164             {
00165                 fds->bytes_available = 0;
00166                 fds->error = rd;
00167                 return 0;
00168             }
00169             else
00170             {
00171                 if (rd < TFTP_PACKET_SIZE)
00172                 {
00173                     fds->is_xfer_end = true;
00174                     LOG_INFO("Received the last packet\n");
00175                 }
00176                 fds->bytes_available = (size_t)rd - sizeof(struct TftpHeader);
00177                 fds->valid_data = fds->bytes_available;
00178                 offset = 0;
00179             }
00180         }
00181         else
00182         {
00183             LOG_INFO("Transfer finished\n");
00184             fds->bytes_available -= fds->bytes_available;
00185             fds->valid_data = 0;
00186             return read_bytes;
00187         }
00188     }
00189 
00190     /* check how many bytes we need to copy */
00191     size_t res = MIN(fds->bytes_available, size);
00192     LOG_INFO("Copying %zd bytes from offset %zd\n", res, offset);
00193     memcpy(_buf, fds->frame.data + offset, res);
00194     fds->bytes_available -= res;
00195     read_bytes += res;
00196     return read_bytes;
00197 }
00198 
00199 static int tftp_error(struct KFile *fd)
00200 {
00201     TftpSession *fds = TFTP_CAST(fd);
00202     return fds->error;
00203 }
00204 
00205 static void tftp_clearerr(struct KFile *fd)
00206 {
00207     TftpSession *fds = TFTP_CAST(fd);
00208     fds->error = 0;
00209 }
00210 
00211 static int tftp_close(struct KFile *fd)
00212 {
00213     TftpSession *fds = TFTP_CAST(fd);
00214     struct errframe err;
00215     if (fds->pending_ack)
00216     {
00217         err.opcode = TFTP_PROTOERR;
00218         err.errcode = TFTP_PROTOERR_ACCESS_VIOLATION;
00219         err.str = '\0';
00220         lwip_sendto(fds->sock, &err, 5, 0, (struct sockaddr *)&fds->addr, fds->addr_len);
00221         LOG_INFO("Closed connection upon user request\n");
00222     }
00223     return 0;
00224 }
00225 
00226 static void resetTftpState(TftpSession *ctx)
00227 {
00228     ctx->block = 0;
00229     ctx->error = 0;
00230     ctx->bytes_available = 0;
00231     ctx->valid_data = 0;
00232     ctx->is_xfer_end = false;
00233     ctx->pending_ack = false;
00234 }
00235 
00247 KFile *tftp_listen(TftpSession *ctx, char *filename, size_t len, TftpOpenMode *mode)
00248 {
00249     DECLARE_TIMEOUT(wait_tm, ctx->timeout);
00250     resetTftpState(ctx);
00251 
00252     int res = tftp_waitEvent(ctx, &wait_tm);
00253     if (res == 0)
00254     {
00255         ctx->error = TFTP_ERR_TIMEOUT;
00256         return NULL;
00257     }
00258     if (res == -1)
00259     {
00260         ctx->error = TFTP_ERR;
00261         return NULL;
00262     }
00263 
00264     // listen onto TFTP port
00265     ctx->addr_len = sizeof(ctx->addr);
00266     ssize_t rd = 0;
00267     if ((rd = lwip_recvfrom(ctx->sock, &ctx->frame, sizeof(Tftpframe), 0, (struct sockaddr *)&ctx->addr, &ctx->addr_len)) > 0)
00268     {
00269         // check if the packet is WRQ, otherwise discard the packet
00270         if (ctx->frame.hdr.opcode == TFTP_WRQ)
00271         {
00272             *mode = TFTP_WRITE;
00273             ctx->pending_ack = true;
00274             strncpy(filename, (char *)&ctx->frame.hdr.th_u, len);
00275             filename[len - 1] = '\0';
00276             ctx->error = 0;
00277             return &ctx->kfile_request;
00278         }
00279         else
00280             *mode = TFTP_READ;
00281     }
00282     ctx->error = TFTP_ERR;
00283     return NULL;
00284 }
00285 
00296 int tftp_init(TftpSession *ctx, unsigned short port, mtime_t timeout)
00297 {
00298     DB(ctx->kfile_request._type = KFT_TFTPSESSION);
00299     ctx->kfile_request.read = tftp_read;
00300     ctx->kfile_request.error = tftp_error;
00301     ctx->kfile_request.clearerr = tftp_clearerr;
00302     ctx->kfile_request.close = tftp_close;
00303     resetTftpState(ctx);
00304 
00305     /* Unused kfile methods */
00306     ctx->kfile_request.seek = NULL;
00307     ctx->kfile_request.write = NULL;
00308     ctx->kfile_request.flush = NULL;
00309     ctx->kfile_request.reopen = NULL;
00310 
00311     struct sockaddr_in sa;
00312     sa.sin_family = AF_INET;
00313         sa.sin_addr.s_addr = htonl(INADDR_ANY);
00314         sa.sin_port = htons(port);
00315     ctx->timeout = timeout;
00316 
00317     ctx->sock = lwip_socket(AF_INET, SOCK_DGRAM, 0);
00318     if (ctx->sock == -1)
00319     {
00320         LOG_INFO("TFTP socket error\n");
00321         return -1;
00322     }
00323 
00324     if(lwip_bind(ctx->sock, (struct sockaddr *)&sa, sizeof(sa)))
00325     {
00326         LOG_INFO("Error binding socket\n");
00327         return -1;
00328     }
00329     return 0;
00330 }
00331