Staging: hv: osd: remove MemAlloc wrapper
[sfrench/cifs-2.6.git] / drivers / staging / hv / RndisFilter.c
1 /*
2  *
3  * Copyright (c) 2009, Microsoft Corporation.
4  *
5  * This program is free software; you can redistribute it and/or modify it
6  * under the terms and conditions of the GNU General Public License,
7  * version 2, as published by the Free Software Foundation.
8  *
9  * This program is distributed in the hope it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
12  * more details.
13  *
14  * You should have received a copy of the GNU General Public License along with
15  * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
16  * Place - Suite 330, Boston, MA 02111-1307 USA.
17  *
18  * Authors:
19  *   Haiyang Zhang <haiyangz@microsoft.com>
20  *   Hank Janssen  <hjanssen@microsoft.com>
21  *
22  */
23
24 #define KERNEL_2_6_27
25
26 #include <linux/kernel.h>
27 #include <linux/mm.h>
28 #include "include/logging.h"
29
30 #include "include/NetVscApi.h"
31 #include "RndisFilter.h"
32
33 //
34 // Data types
35 //
36
37 typedef struct _RNDIS_FILTER_DRIVER_OBJECT {
38         // The original driver
39         NETVSC_DRIVER_OBJECT            InnerDriver;
40
41 } RNDIS_FILTER_DRIVER_OBJECT;
42
/* Lifecycle of an RNDIS device as tracked by this filter. */
typedef enum {
	RNDIS_DEV_UNINITIALIZED = 0,	/* initial state; also after failed init or halt */
	RNDIS_DEV_INITIALIZING = 1,	/* INITIALIZE request sent, awaiting completion */
	RNDIS_DEV_INITIALIZED = 2,	/* INITIALIZE completed with RNDIS_STATUS_SUCCESS */
	RNDIS_DEV_DATAINITIALIZED = 3,	/* packet filter set by RndisFilterOpenDevice() */
} RNDIS_DEVICE_STATE;
49
50 typedef struct _RNDIS_DEVICE {
51         NETVSC_DEVICE                   *NetDevice;
52
53         RNDIS_DEVICE_STATE              State;
54         u32                                     LinkStatus;
55         u32                                     NewRequestId;
56
57         HANDLE                                  RequestLock;
58         LIST_ENTRY                              RequestList;
59
60         unsigned char                                   HwMacAddr[HW_MACADDR_LEN];
61 } RNDIS_DEVICE;
62
63
64 typedef struct _RNDIS_REQUEST {
65         LIST_ENTRY                                      ListEntry;
66         HANDLE                                          WaitEvent;
67
68         // FIXME: We assumed a fixed size response here. If we do ever need to handle a bigger response,
69         // we can either define a max response message or add a response buffer variable above this field
70         RNDIS_MESSAGE                           ResponseMessage;
71
72         // Simplify allocation by having a netvsc packet inline
73         NETVSC_PACKET                           Packet;
74         PAGE_BUFFER                                     Buffer;
75         // FIXME: We assumed a fixed size request here.
76         RNDIS_MESSAGE                           RequestMessage;
77 } RNDIS_REQUEST;
78
79
80 typedef struct _RNDIS_FILTER_PACKET {
81         void                                            *CompletionContext;
82         PFN_ON_SENDRECVCOMPLETION       OnCompletion;
83
84         RNDIS_MESSAGE                           Message;
85 } RNDIS_FILTER_PACKET;
86
87 //
88 // Internal routines
89 //
90 static int
91 RndisFilterSendRequest(
92         RNDIS_DEVICE    *Device,
93         RNDIS_REQUEST   *Request
94         );
95
96 static void
97 RndisFilterReceiveResponse(
98         RNDIS_DEVICE    *Device,
99         RNDIS_MESSAGE   *Response
100         );
101
102 static void
103 RndisFilterReceiveIndicateStatus(
104         RNDIS_DEVICE    *Device,
105         RNDIS_MESSAGE   *Response
106         );
107
108 static void
109 RndisFilterReceiveData(
110         RNDIS_DEVICE    *Device,
111         RNDIS_MESSAGE   *Message,
112         NETVSC_PACKET   *Packet
113         );
114
115 static int
116 RndisFilterOnReceive(
117         DEVICE_OBJECT           *Device,
118         NETVSC_PACKET           *Packet
119         );
120
121 static int
122 RndisFilterQueryDevice(
123         RNDIS_DEVICE    *Device,
124         u32                     Oid,
125         void                    *Result,
126         u32                     *ResultSize
127         );
128
129 static inline int
130 RndisFilterQueryDeviceMac(
131         RNDIS_DEVICE    *Device
132         );
133
134 static inline int
135 RndisFilterQueryDeviceLinkStatus(
136         RNDIS_DEVICE    *Device
137         );
138
139 static int
140 RndisFilterSetPacketFilter(
141         RNDIS_DEVICE    *Device,
142         u32                     NewFilter
143         );
144
145 static int
146 RndisFilterInitDevice(
147         RNDIS_DEVICE            *Device
148         );
149
150 static int
151 RndisFilterOpenDevice(
152         RNDIS_DEVICE            *Device
153         );
154
155 static int
156 RndisFilterCloseDevice(
157         RNDIS_DEVICE            *Device
158         );
159
160 static int
161 RndisFilterOnDeviceAdd(
162         DEVICE_OBJECT   *Device,
163         void                    *AdditionalInfo
164         );
165
166 static int
167 RndisFilterOnDeviceRemove(
168         DEVICE_OBJECT *Device
169         );
170
171 static void
172 RndisFilterOnCleanup(
173         DRIVER_OBJECT *Driver
174         );
175
176 static int
177 RndisFilterOnOpen(
178         DEVICE_OBJECT           *Device
179         );
180
181 static int
182 RndisFilterOnClose(
183         DEVICE_OBJECT           *Device
184         );
185
186 static int
187 RndisFilterOnSend(
188         DEVICE_OBJECT           *Device,
189         NETVSC_PACKET           *Packet
190         );
191
192 static void
193 RndisFilterOnSendCompletion(
194    void *Context
195         );
196
197 static void
198 RndisFilterOnSendRequestCompletion(
199    void *Context
200         );
201
202 //
203 // Global var
204 //
205
206 // The one and only
207 RNDIS_FILTER_DRIVER_OBJECT gRndisFilter;
208
209 static inline RNDIS_DEVICE* GetRndisDevice(void)
210 {
211         RNDIS_DEVICE *device;
212
213         device = MemAllocZeroed(sizeof(RNDIS_DEVICE));
214         if (!device)
215         {
216                 return NULL;
217         }
218
219         device->RequestLock = SpinlockCreate();
220         if (!device->RequestLock)
221         {
222                 MemFree(device);
223                 return NULL;
224         }
225
226         INITIALIZE_LIST_HEAD(&device->RequestList);
227
228         device->State = RNDIS_DEV_UNINITIALIZED;
229
230         return device;
231 }
232
233 static inline void PutRndisDevice(RNDIS_DEVICE *Device)
234 {
235         SpinlockClose(Device->RequestLock);
236         MemFree(Device);
237 }
238
239 static inline RNDIS_REQUEST* GetRndisRequest(RNDIS_DEVICE *Device, u32 MessageType, u32 MessageLength)
240 {
241         RNDIS_REQUEST *request;
242         RNDIS_MESSAGE *rndisMessage;
243         RNDIS_SET_REQUEST *set;
244
245         request = MemAllocZeroed(sizeof(RNDIS_REQUEST));
246         if (!request)
247         {
248                 return NULL;
249         }
250
251         request->WaitEvent = WaitEventCreate();
252         if (!request->WaitEvent)
253         {
254                 MemFree(request);
255                 return NULL;
256         }
257
258         rndisMessage = &request->RequestMessage;
259         rndisMessage->NdisMessageType = MessageType;
260         rndisMessage->MessageLength = MessageLength;
261
262         // Set the request id. This field is always after the rndis header for request/response packet types so
263         // we just used the SetRequest as a template
264         set = &rndisMessage->Message.SetRequest;
265         set->RequestId = InterlockedIncrement((int*)&Device->NewRequestId);
266
267         // Add to the request list
268         SpinlockAcquire(Device->RequestLock);
269         INSERT_TAIL_LIST(&Device->RequestList, &request->ListEntry);
270         SpinlockRelease(Device->RequestLock);
271
272         return request;
273 }
274
275 static inline void PutRndisRequest(RNDIS_DEVICE *Device, RNDIS_REQUEST *Request)
276 {
277         SpinlockAcquire(Device->RequestLock);
278         REMOVE_ENTRY_LIST(&Request->ListEntry);
279         SpinlockRelease(Device->RequestLock);
280
281         WaitEventClose(Request->WaitEvent);
282         MemFree(Request);
283 }
284
285 static inline void DumpRndisMessage(RNDIS_MESSAGE *RndisMessage)
286 {
287         switch (RndisMessage->NdisMessageType)
288         {
289         case REMOTE_NDIS_PACKET_MSG:
290                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_PACKET_MSG (len %u, data offset %u data len %u, # oob %u, oob offset %u, oob len %u, pkt offset %u, pkt len %u",
291                         RndisMessage->MessageLength,
292                         RndisMessage->Message.Packet.DataOffset,
293                         RndisMessage->Message.Packet.DataLength,
294                         RndisMessage->Message.Packet.NumOOBDataElements,
295                         RndisMessage->Message.Packet.OOBDataOffset,
296                         RndisMessage->Message.Packet.OOBDataLength,
297                         RndisMessage->Message.Packet.PerPacketInfoOffset,
298                         RndisMessage->Message.Packet.PerPacketInfoLength);
299                 break;
300
301         case REMOTE_NDIS_INITIALIZE_CMPLT:
302                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_INITIALIZE_CMPLT (len %u, id 0x%x, status 0x%x, major %d, minor %d, device flags %d, max xfer size 0x%x, max pkts %u, pkt aligned %u)",
303                         RndisMessage->MessageLength,
304                         RndisMessage->Message.InitializeComplete.RequestId,
305                         RndisMessage->Message.InitializeComplete.Status,
306                         RndisMessage->Message.InitializeComplete.MajorVersion,
307                         RndisMessage->Message.InitializeComplete.MinorVersion,
308                         RndisMessage->Message.InitializeComplete.DeviceFlags,
309                         RndisMessage->Message.InitializeComplete.MaxTransferSize,
310                         RndisMessage->Message.InitializeComplete.MaxPacketsPerMessage,
311                         RndisMessage->Message.InitializeComplete.PacketAlignmentFactor);
312                 break;
313
314         case REMOTE_NDIS_QUERY_CMPLT:
315                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_QUERY_CMPLT (len %u, id 0x%x, status 0x%x, buf len %u, buf offset %u)",
316                         RndisMessage->MessageLength,
317                         RndisMessage->Message.QueryComplete.RequestId,
318                         RndisMessage->Message.QueryComplete.Status,
319                         RndisMessage->Message.QueryComplete.InformationBufferLength,
320                         RndisMessage->Message.QueryComplete.InformationBufferOffset);
321                 break;
322
323         case REMOTE_NDIS_SET_CMPLT:
324                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_SET_CMPLT (len %u, id 0x%x, status 0x%x)",
325                         RndisMessage->MessageLength,
326                         RndisMessage->Message.SetComplete.RequestId,
327                         RndisMessage->Message.SetComplete.Status);
328                 break;
329
330         case REMOTE_NDIS_INDICATE_STATUS_MSG:
331                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_INDICATE_STATUS_MSG (len %u, status 0x%x, buf len %u, buf offset %u)",
332                         RndisMessage->MessageLength,
333                         RndisMessage->Message.IndicateStatus.Status,
334                         RndisMessage->Message.IndicateStatus.StatusBufferLength,
335                         RndisMessage->Message.IndicateStatus.StatusBufferOffset);
336                 break;
337
338         default:
339                 DPRINT_DBG(NETVSC, "0x%x (len %u)",
340                         RndisMessage->NdisMessageType,
341                         RndisMessage->MessageLength);
342                 break;
343         }
344 }
345
346 static int
347 RndisFilterSendRequest(
348         RNDIS_DEVICE    *Device,
349         RNDIS_REQUEST   *Request
350         )
351 {
352         int ret=0;
353         NETVSC_PACKET *packet;
354
355         DPRINT_ENTER(NETVSC);
356
357         // Setup the packet to send it
358         packet = &Request->Packet;
359
360         packet->IsDataPacket = false;
361         packet->TotalDataBufferLength = Request->RequestMessage.MessageLength;
362         packet->PageBufferCount = 1;
363
364         packet->PageBuffers[0].Pfn = GetPhysicalAddress(&Request->RequestMessage) >> PAGE_SHIFT;
365         packet->PageBuffers[0].Length = Request->RequestMessage.MessageLength;
366         packet->PageBuffers[0].Offset = (unsigned long)&Request->RequestMessage & (PAGE_SIZE -1);
367
368         packet->Completion.Send.SendCompletionContext = Request;//packet;
369         packet->Completion.Send.OnSendCompletion = RndisFilterOnSendRequestCompletion;
370         packet->Completion.Send.SendCompletionTid = (unsigned long)Device;
371
372         ret = gRndisFilter.InnerDriver.OnSend(Device->NetDevice->Device, packet);
373         DPRINT_EXIT(NETVSC);
374         return ret;
375 }
376
377
378 static void
379 RndisFilterReceiveResponse(
380         RNDIS_DEVICE    *Device,
381         RNDIS_MESSAGE   *Response
382         )
383 {
384         LIST_ENTRY *anchor;
385         LIST_ENTRY *curr;
386         RNDIS_REQUEST *request=NULL;
387         bool found = false;
388
389         DPRINT_ENTER(NETVSC);
390
391         SpinlockAcquire(Device->RequestLock);
392         ITERATE_LIST_ENTRIES(anchor, curr, &Device->RequestList)
393         {
394                 request = CONTAINING_RECORD(curr, RNDIS_REQUEST, ListEntry);
395
396                 // All request/response message contains RequestId as the 1st field
397                 if (request->RequestMessage.Message.InitializeRequest.RequestId == Response->Message.InitializeComplete.RequestId)
398                 {
399                         DPRINT_DBG(NETVSC, "found rndis request for this response (id 0x%x req type 0x%x res type 0x%x)",
400                                 request->RequestMessage.Message.InitializeRequest.RequestId, request->RequestMessage.NdisMessageType, Response->NdisMessageType);
401
402                         found = true;
403                         break;
404                 }
405         }
406         SpinlockRelease(Device->RequestLock);
407
408         if (found)
409         {
410                 if (Response->MessageLength <= sizeof(RNDIS_MESSAGE))
411                 {
412                         memcpy(&request->ResponseMessage, Response, Response->MessageLength);
413                 }
414                 else
415                 {
416                         DPRINT_ERR(NETVSC, "rndis response buffer overflow detected (size %u max %u)", Response->MessageLength, sizeof(RNDIS_FILTER_PACKET));
417
418                         if (Response->NdisMessageType == REMOTE_NDIS_RESET_CMPLT) // does not have a request id field
419                         {
420                                 request->ResponseMessage.Message.ResetComplete.Status = STATUS_BUFFER_OVERFLOW;
421                         }
422                         else
423                         {
424                                 request->ResponseMessage.Message.InitializeComplete.Status = STATUS_BUFFER_OVERFLOW;
425                         }
426                 }
427
428                 WaitEventSet(request->WaitEvent);
429         }
430         else
431         {
432                 DPRINT_ERR(NETVSC, "no rndis request found for this response (id 0x%x res type 0x%x)",
433                                 Response->Message.InitializeComplete.RequestId, Response->NdisMessageType);
434         }
435
436         DPRINT_EXIT(NETVSC);
437 }
438
439 static void
440 RndisFilterReceiveIndicateStatus(
441         RNDIS_DEVICE    *Device,
442         RNDIS_MESSAGE   *Response
443         )
444 {
445         RNDIS_INDICATE_STATUS *indicate = &Response->Message.IndicateStatus;
446
447         if (indicate->Status == RNDIS_STATUS_MEDIA_CONNECT)
448         {
449                 gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 1);
450         }
451         else if (indicate->Status == RNDIS_STATUS_MEDIA_DISCONNECT)
452         {
453                 gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 0);
454         }
455         else
456         {
457                 // TODO:
458         }
459 }
460
461 static void
462 RndisFilterReceiveData(
463         RNDIS_DEVICE    *Device,
464         RNDIS_MESSAGE   *Message,
465         NETVSC_PACKET   *Packet
466         )
467 {
468         RNDIS_PACKET *rndisPacket;
469         u32 dataOffset;
470
471         DPRINT_ENTER(NETVSC);
472
473         // empty ethernet frame ??
474         ASSERT(Packet->PageBuffers[0].Length > RNDIS_MESSAGE_SIZE(RNDIS_PACKET));
475
476         rndisPacket = &Message->Message.Packet;
477
478         // FIXME: Handle multiple rndis pkt msgs that maybe enclosed in this
479         // netvsc packet (ie TotalDataBufferLength != MessageLength)
480
481         // Remove the rndis header and pass it back up the stack
482         dataOffset = RNDIS_HEADER_SIZE + rndisPacket->DataOffset;
483
484         Packet->TotalDataBufferLength -= dataOffset;
485         Packet->PageBuffers[0].Offset += dataOffset;
486         Packet->PageBuffers[0].Length -= dataOffset;
487
488         Packet->IsDataPacket = true;
489
490         gRndisFilter.InnerDriver.OnReceiveCallback(Device->NetDevice->Device, Packet);
491
492         DPRINT_EXIT(NETVSC);
493 }
494
495 static int
496 RndisFilterOnReceive(
497         DEVICE_OBJECT           *Device,
498         NETVSC_PACKET           *Packet
499         )
500 {
501         NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
502         RNDIS_DEVICE *rndisDevice;
503         RNDIS_MESSAGE rndisMessage;
504         RNDIS_MESSAGE *rndisHeader;
505
506         DPRINT_ENTER(NETVSC);
507
508         ASSERT(netDevice);
509         //Make sure the rndis device state is initialized
510         if (!netDevice->Extension)
511         {
512                 DPRINT_ERR(NETVSC, "got rndis message but no rndis device...dropping this message!");
513                 DPRINT_EXIT(NETVSC);
514                 return -1;
515         }
516
517         rndisDevice = (RNDIS_DEVICE*)netDevice->Extension;
518         if (rndisDevice->State == RNDIS_DEV_UNINITIALIZED)
519         {
520                 DPRINT_ERR(NETVSC, "got rndis message but rndis device uninitialized...dropping this message!");
521                 DPRINT_EXIT(NETVSC);
522                 return -1;
523         }
524
525         rndisHeader = (RNDIS_MESSAGE*)PageMapVirtualAddress(Packet->PageBuffers[0].Pfn);
526
527         rndisHeader = (void*)((unsigned long)rndisHeader + Packet->PageBuffers[0].Offset);
528
529         // Make sure we got a valid rndis message
530         // FIXME: There seems to be a bug in set completion msg where its MessageLength is 16 bytes but
531         // the ByteCount field in the xfer page range shows 52 bytes
532 #if 0
533         if ( Packet->TotalDataBufferLength != rndisHeader->MessageLength )
534         {
535                 PageUnmapVirtualAddress((void*)(unsigned long)rndisHeader - Packet->PageBuffers[0].Offset);
536
537                 DPRINT_ERR(NETVSC, "invalid rndis message? (expected %u bytes got %u)...dropping this message!",
538                         rndisHeader->MessageLength, Packet->TotalDataBufferLength);
539                 DPRINT_EXIT(NETVSC);
540                 return -1;
541         }
542 #endif
543
544         if ((rndisHeader->NdisMessageType != REMOTE_NDIS_PACKET_MSG) && (rndisHeader->MessageLength > sizeof(RNDIS_MESSAGE)))
545         {
546                 DPRINT_ERR(NETVSC, "incoming rndis message buffer overflow detected (got %u, max %u)...marking it an error!",
547                         rndisHeader->MessageLength, sizeof(RNDIS_MESSAGE));
548         }
549
550         memcpy(&rndisMessage, rndisHeader, (rndisHeader->MessageLength > sizeof(RNDIS_MESSAGE))?sizeof(RNDIS_MESSAGE):rndisHeader->MessageLength);
551
552         PageUnmapVirtualAddress((void*)(unsigned long)rndisHeader - Packet->PageBuffers[0].Offset);
553
554         DumpRndisMessage(&rndisMessage);
555
556         switch (rndisMessage.NdisMessageType)
557         {
558                 // data msg
559         case REMOTE_NDIS_PACKET_MSG:
560                 RndisFilterReceiveData(rndisDevice, &rndisMessage, Packet);
561                 break;
562
563                 // completion msgs
564         case REMOTE_NDIS_INITIALIZE_CMPLT:
565         case REMOTE_NDIS_QUERY_CMPLT:
566         case REMOTE_NDIS_SET_CMPLT:
567         //case REMOTE_NDIS_RESET_CMPLT:
568         //case REMOTE_NDIS_KEEPALIVE_CMPLT:
569                 RndisFilterReceiveResponse(rndisDevice, &rndisMessage);
570                 break;
571
572                 // notification msgs
573         case REMOTE_NDIS_INDICATE_STATUS_MSG:
574                 RndisFilterReceiveIndicateStatus(rndisDevice, &rndisMessage);
575                 break;
576         default:
577                 DPRINT_ERR(NETVSC, "unhandled rndis message (type %u len %u)", rndisMessage.NdisMessageType, rndisMessage.MessageLength);
578                 break;
579         }
580
581         DPRINT_EXIT(NETVSC);
582         return 0;
583 }
584
585
586 static int
587 RndisFilterQueryDevice(
588         RNDIS_DEVICE    *Device,
589         u32                     Oid,
590         void                    *Result,
591         u32                     *ResultSize
592         )
593 {
594         RNDIS_REQUEST *request;
595         u32 inresultSize = *ResultSize;
596         RNDIS_QUERY_REQUEST *query;
597         RNDIS_QUERY_COMPLETE *queryComplete;
598         int ret=0;
599
600         DPRINT_ENTER(NETVSC);
601
602         ASSERT(Result);
603
604         *ResultSize = 0;
605         request = GetRndisRequest(Device, REMOTE_NDIS_QUERY_MSG, RNDIS_MESSAGE_SIZE(RNDIS_QUERY_REQUEST));
606         if (!request)
607         {
608                 ret = -1;
609                 goto Cleanup;
610         }
611
612         // Setup the rndis query
613         query = &request->RequestMessage.Message.QueryRequest;
614         query->Oid = Oid;
615         query->InformationBufferOffset = sizeof(RNDIS_QUERY_REQUEST);
616         query->InformationBufferLength = 0;
617         query->DeviceVcHandle = 0;
618
619         ret = RndisFilterSendRequest(Device, request);
620         if (ret != 0)
621         {
622                 goto Cleanup;
623         }
624
625         WaitEventWait(request->WaitEvent);
626
627         // Copy the response back
628         queryComplete = &request->ResponseMessage.Message.QueryComplete;
629
630         if (queryComplete->InformationBufferLength > inresultSize)
631         {
632                 ret = -1;
633                 goto Cleanup;
634         }
635
636         memcpy(Result,
637                         (void*)((unsigned long)queryComplete + queryComplete->InformationBufferOffset),
638                         queryComplete->InformationBufferLength);
639
640         *ResultSize = queryComplete->InformationBufferLength;
641
642 Cleanup:
643         if (request)
644         {
645                 PutRndisRequest(Device, request);
646         }
647         DPRINT_EXIT(NETVSC);
648
649         return ret;
650 }
651
652 static inline int
653 RndisFilterQueryDeviceMac(
654         RNDIS_DEVICE    *Device
655         )
656 {
657         u32 size=HW_MACADDR_LEN;
658
659         return RndisFilterQueryDevice(Device,
660                                                                         RNDIS_OID_802_3_PERMANENT_ADDRESS,
661                                                                         Device->HwMacAddr,
662                                                                         &size);
663 }
664
665 static inline int
666 RndisFilterQueryDeviceLinkStatus(
667         RNDIS_DEVICE    *Device
668         )
669 {
670         u32 size=sizeof(u32);
671
672         return RndisFilterQueryDevice(Device,
673                                                                         RNDIS_OID_GEN_MEDIA_CONNECT_STATUS,
674                                                                         &Device->LinkStatus,
675                                                                         &size);
676 }
677
678 static int
679 RndisFilterSetPacketFilter(
680         RNDIS_DEVICE    *Device,
681         u32                     NewFilter
682         )
683 {
684         RNDIS_REQUEST *request;
685         RNDIS_SET_REQUEST *set;
686         RNDIS_SET_COMPLETE *setComplete;
687         u32 status;
688         int ret;
689
690         DPRINT_ENTER(NETVSC);
691
692         ASSERT(RNDIS_MESSAGE_SIZE(RNDIS_SET_REQUEST) + sizeof(u32) <= sizeof(RNDIS_MESSAGE));
693
694         request = GetRndisRequest(Device, REMOTE_NDIS_SET_MSG, RNDIS_MESSAGE_SIZE(RNDIS_SET_REQUEST) + sizeof(u32));
695         if (!request)
696         {
697                 ret = -1;
698                 goto Cleanup;
699         }
700
701         // Setup the rndis set
702         set = &request->RequestMessage.Message.SetRequest;
703         set->Oid = RNDIS_OID_GEN_CURRENT_PACKET_FILTER;
704         set->InformationBufferLength = sizeof(u32);
705         set->InformationBufferOffset = sizeof(RNDIS_SET_REQUEST);
706
707         memcpy((void*)(unsigned long)set + sizeof(RNDIS_SET_REQUEST), &NewFilter, sizeof(u32));
708
709         ret = RndisFilterSendRequest(Device, request);
710         if (ret != 0)
711         {
712                 goto Cleanup;
713         }
714
715         ret = WaitEventWaitEx(request->WaitEvent, 2000/*2sec*/);
716         if (!ret)
717         {
718                 ret = -1;
719                 DPRINT_ERR(NETVSC, "timeout before we got a set response...");
720                 // We cant deallocate the request since we may still receive a send completion for it.
721                 goto Exit;
722         }
723         else
724         {
725                 if (ret > 0)
726                 {
727                         ret = 0;
728                 }
729                 setComplete = &request->ResponseMessage.Message.SetComplete;
730                 status = setComplete->Status;
731         }
732
733 Cleanup:
734         if (request)
735         {
736                 PutRndisRequest(Device, request);
737         }
738 Exit:
739         DPRINT_EXIT(NETVSC);
740
741         return ret;
742 }
743
744 int
745 RndisFilterInit(
746         NETVSC_DRIVER_OBJECT    *Driver
747         )
748 {
749         DPRINT_ENTER(NETVSC);
750
751         DPRINT_DBG(NETVSC, "sizeof(RNDIS_FILTER_PACKET) == %d", sizeof(RNDIS_FILTER_PACKET));
752
753         Driver->RequestExtSize = sizeof(RNDIS_FILTER_PACKET);
754         Driver->AdditionalRequestPageBufferCount = 1; // For rndis header
755
756         //Driver->Context = rndisDriver;
757
758         memset(&gRndisFilter, 0, sizeof(RNDIS_FILTER_DRIVER_OBJECT));
759
760         /*rndisDriver->Driver = Driver;
761
762         ASSERT(Driver->OnLinkStatusChanged);
763         rndisDriver->OnLinkStatusChanged = Driver->OnLinkStatusChanged;*/
764
765         // Save the original dispatch handlers before we override it
766         gRndisFilter.InnerDriver.Base.OnDeviceAdd = Driver->Base.OnDeviceAdd;
767         gRndisFilter.InnerDriver.Base.OnDeviceRemove = Driver->Base.OnDeviceRemove;
768         gRndisFilter.InnerDriver.Base.OnCleanup = Driver->Base.OnCleanup;
769
770         ASSERT(Driver->OnSend);
771         ASSERT(Driver->OnReceiveCallback);
772         gRndisFilter.InnerDriver.OnSend = Driver->OnSend;
773         gRndisFilter.InnerDriver.OnReceiveCallback = Driver->OnReceiveCallback;
774         gRndisFilter.InnerDriver.OnLinkStatusChanged = Driver->OnLinkStatusChanged;
775
776         // Override
777         Driver->Base.OnDeviceAdd = RndisFilterOnDeviceAdd;
778         Driver->Base.OnDeviceRemove = RndisFilterOnDeviceRemove;
779         Driver->Base.OnCleanup = RndisFilterOnCleanup;
780         Driver->OnSend = RndisFilterOnSend;
781         Driver->OnOpen = RndisFilterOnOpen;
782         Driver->OnClose = RndisFilterOnClose;
783         //Driver->QueryLinkStatus = RndisFilterQueryDeviceLinkStatus;
784         Driver->OnReceiveCallback = RndisFilterOnReceive;
785
786         DPRINT_EXIT(NETVSC);
787
788         return 0;
789 }
790
791 static int
792 RndisFilterInitDevice(
793         RNDIS_DEVICE    *Device
794         )
795 {
796         RNDIS_REQUEST *request;
797         RNDIS_INITIALIZE_REQUEST *init;
798         RNDIS_INITIALIZE_COMPLETE *initComplete;
799         u32 status;
800         int ret;
801
802         DPRINT_ENTER(NETVSC);
803
804         request = GetRndisRequest(Device, REMOTE_NDIS_INITIALIZE_MSG, RNDIS_MESSAGE_SIZE(RNDIS_INITIALIZE_REQUEST));
805         if (!request)
806         {
807                 ret = -1;
808                 goto Cleanup;
809         }
810
811         // Setup the rndis set
812         init = &request->RequestMessage.Message.InitializeRequest;
813         init->MajorVersion = RNDIS_MAJOR_VERSION;
814         init->MinorVersion = RNDIS_MINOR_VERSION;
815         init->MaxTransferSize = 2048; // FIXME: Use 1536 - rounded ethernet frame size
816
817         Device->State = RNDIS_DEV_INITIALIZING;
818
819         ret = RndisFilterSendRequest(Device, request);
820         if (ret != 0)
821         {
822                 Device->State = RNDIS_DEV_UNINITIALIZED;
823                 goto Cleanup;
824         }
825
826         WaitEventWait(request->WaitEvent);
827
828         initComplete = &request->ResponseMessage.Message.InitializeComplete;
829         status = initComplete->Status;
830         if (status == RNDIS_STATUS_SUCCESS)
831         {
832                 Device->State = RNDIS_DEV_INITIALIZED;
833                 ret = 0;
834         }
835         else
836         {
837                 Device->State = RNDIS_DEV_UNINITIALIZED;
838                 ret = -1;
839         }
840
841 Cleanup:
842         if (request)
843         {
844                 PutRndisRequest(Device, request);
845         }
846         DPRINT_EXIT(NETVSC);
847
848         return ret;
849 }
850
851 static void
852 RndisFilterHaltDevice(
853         RNDIS_DEVICE    *Device
854         )
855 {
856         RNDIS_REQUEST *request;
857         RNDIS_HALT_REQUEST *halt;
858
859         DPRINT_ENTER(NETVSC);
860
861         // Attempt to do a rndis device halt
862         request = GetRndisRequest(Device, REMOTE_NDIS_HALT_MSG, RNDIS_MESSAGE_SIZE(RNDIS_HALT_REQUEST));
863         if (!request)
864         {
865                 goto Cleanup;
866         }
867
868         // Setup the rndis set
869         halt = &request->RequestMessage.Message.HaltRequest;
870         halt->RequestId = InterlockedIncrement((int*)&Device->NewRequestId);
871
872         // Ignore return since this msg is optional.
873         RndisFilterSendRequest(Device, request);
874
875         Device->State = RNDIS_DEV_UNINITIALIZED;
876
877 Cleanup:
878         if (request)
879         {
880                 PutRndisRequest(Device, request);
881         }
882         DPRINT_EXIT(NETVSC);
883         return;
884 }
885
886
887 static int
888 RndisFilterOpenDevice(
889         RNDIS_DEVICE    *Device
890         )
891 {
892         int ret=0;
893
894         DPRINT_ENTER(NETVSC);
895
896         if (Device->State != RNDIS_DEV_INITIALIZED)
897                 return 0;
898
899         ret = RndisFilterSetPacketFilter(Device, NDIS_PACKET_TYPE_BROADCAST|NDIS_PACKET_TYPE_DIRECTED);
900         if (ret == 0)
901         {
902                 Device->State = RNDIS_DEV_DATAINITIALIZED;
903         }
904
905         DPRINT_EXIT(NETVSC);
906         return ret;
907 }
908
909 static int
910 RndisFilterCloseDevice(
911         RNDIS_DEVICE            *Device
912         )
913 {
914         int ret;
915
916         DPRINT_ENTER(NETVSC);
917
918         if (Device->State != RNDIS_DEV_DATAINITIALIZED)
919                 return 0;
920
921         ret = RndisFilterSetPacketFilter(Device, 0);
922         if (ret == 0)
923         {
924                 Device->State = RNDIS_DEV_INITIALIZED;
925         }
926
927         DPRINT_EXIT(NETVSC);
928
929         return ret;
930 }
931
932
933 int
934 RndisFilterOnDeviceAdd(
935         DEVICE_OBJECT   *Device,
936         void                    *AdditionalInfo
937         )
938 {
939         int ret;
940         NETVSC_DEVICE *netDevice;
941         RNDIS_DEVICE *rndisDevice;
942         NETVSC_DEVICE_INFO *deviceInfo = (NETVSC_DEVICE_INFO*)AdditionalInfo;
943
944         DPRINT_ENTER(NETVSC);
945
946         rndisDevice = GetRndisDevice();
947         if (!rndisDevice)
948         {
949                 DPRINT_EXIT(NETVSC);
950                 return -1;
951         }
952
953         DPRINT_DBG(NETVSC, "rndis device object allocated - %p", rndisDevice);
954
955         // Let the inner driver handle this first to create the netvsc channel
956         // NOTE! Once the channel is created, we may get a receive callback
957         // (RndisFilterOnReceive()) before this call is completed
958         ret = gRndisFilter.InnerDriver.Base.OnDeviceAdd(Device, AdditionalInfo);
959         if (ret != 0)
960         {
961                 PutRndisDevice(rndisDevice);
962                 DPRINT_EXIT(NETVSC);
963                 return ret;
964         }
965
966         //
967         // Initialize the rndis device
968         //
969         netDevice = (NETVSC_DEVICE*)Device->Extension;
970         ASSERT(netDevice);
971         ASSERT(netDevice->Device);
972
973         netDevice->Extension = rndisDevice;
974         rndisDevice->NetDevice = netDevice;
975
976         // Send the rndis initialization message
977         ret = RndisFilterInitDevice(rndisDevice);
978         if (ret != 0)
979         {
980                 // TODO: If rndis init failed, we will need to shut down the channel
981         }
982
983         // Get the mac address
984         ret = RndisFilterQueryDeviceMac(rndisDevice);
985         if (ret != 0)
986         {
987                 // TODO: shutdown rndis device and the channel
988         }
989
990         DPRINT_INFO(NETVSC, "Device 0x%p mac addr %02x%02x%02x%02x%02x%02x",
991                                 rndisDevice,
992                                 rndisDevice->HwMacAddr[0],
993                                 rndisDevice->HwMacAddr[1],
994                                 rndisDevice->HwMacAddr[2],
995                                 rndisDevice->HwMacAddr[3],
996                                 rndisDevice->HwMacAddr[4],
997                                 rndisDevice->HwMacAddr[5]);
998
999         memcpy(deviceInfo->MacAddr, rndisDevice->HwMacAddr, HW_MACADDR_LEN);
1000
1001         RndisFilterQueryDeviceLinkStatus(rndisDevice);
1002
1003         deviceInfo->LinkState = rndisDevice->LinkStatus;
1004         DPRINT_INFO(NETVSC, "Device 0x%p link state %s", rndisDevice, ((deviceInfo->LinkState)?("down"):("up")));
1005
1006         DPRINT_EXIT(NETVSC);
1007
1008         return ret;
1009 }
1010
1011
1012 static int
1013 RndisFilterOnDeviceRemove(
1014         DEVICE_OBJECT *Device
1015         )
1016 {
1017         NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
1018         RNDIS_DEVICE *rndisDevice = (RNDIS_DEVICE*)netDevice->Extension;
1019
1020         DPRINT_ENTER(NETVSC);
1021
1022         // Halt and release the rndis device
1023         RndisFilterHaltDevice(rndisDevice);
1024
1025         PutRndisDevice(rndisDevice);
1026         netDevice->Extension = NULL;
1027
1028         // Pass control to inner driver to remove the device
1029         gRndisFilter.InnerDriver.Base.OnDeviceRemove(Device);
1030
1031         DPRINT_EXIT(NETVSC);
1032
1033         return 0;
1034 }
1035
1036
1037 static void
1038 RndisFilterOnCleanup(
1039         DRIVER_OBJECT *Driver
1040         )
1041 {
1042         DPRINT_ENTER(NETVSC);
1043
1044         DPRINT_EXIT(NETVSC);
1045 }
1046
1047 static int
1048 RndisFilterOnOpen(
1049         DEVICE_OBJECT           *Device
1050         )
1051 {
1052         int ret;
1053         NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
1054
1055         DPRINT_ENTER(NETVSC);
1056
1057         ASSERT(netDevice);
1058         ret = RndisFilterOpenDevice((RNDIS_DEVICE*)netDevice->Extension);
1059
1060         DPRINT_EXIT(NETVSC);
1061
1062         return ret;
1063 }
1064
1065 static int
1066 RndisFilterOnClose(
1067         DEVICE_OBJECT           *Device
1068         )
1069 {
1070         int ret;
1071         NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
1072
1073         DPRINT_ENTER(NETVSC);
1074
1075         ASSERT(netDevice);
1076         ret = RndisFilterCloseDevice((RNDIS_DEVICE*)netDevice->Extension);
1077
1078         DPRINT_EXIT(NETVSC);
1079
1080         return ret;
1081 }
1082
1083
1084 static int
1085 RndisFilterOnSend(
1086         DEVICE_OBJECT           *Device,
1087         NETVSC_PACKET           *Packet
1088         )
1089 {
1090         int ret=0;
1091         RNDIS_FILTER_PACKET *filterPacket;
1092         RNDIS_MESSAGE *rndisMessage;
1093         RNDIS_PACKET *rndisPacket;
1094         u32 rndisMessageSize;
1095
1096         DPRINT_ENTER(NETVSC);
1097
1098         // Add the rndis header
1099         filterPacket = (RNDIS_FILTER_PACKET*)Packet->Extension;
1100         ASSERT(filterPacket);
1101
1102         memset(filterPacket, 0, sizeof(RNDIS_FILTER_PACKET));
1103
1104         rndisMessage = &filterPacket->Message;
1105         rndisMessageSize = RNDIS_MESSAGE_SIZE(RNDIS_PACKET);
1106
1107         rndisMessage->NdisMessageType = REMOTE_NDIS_PACKET_MSG;
1108         rndisMessage->MessageLength = Packet->TotalDataBufferLength + rndisMessageSize;
1109
1110         rndisPacket = &rndisMessage->Message.Packet;
1111         rndisPacket->DataOffset = sizeof(RNDIS_PACKET);
1112         rndisPacket->DataLength = Packet->TotalDataBufferLength;
1113
1114         Packet->IsDataPacket = true;
1115         Packet->PageBuffers[0].Pfn              = GetPhysicalAddress(rndisMessage) >> PAGE_SHIFT;
1116         Packet->PageBuffers[0].Offset   = (unsigned long)rndisMessage & (PAGE_SIZE-1);
1117         Packet->PageBuffers[0].Length   = rndisMessageSize;
1118
1119         // Save the packet send completion and context
1120         filterPacket->OnCompletion = Packet->Completion.Send.OnSendCompletion;
1121         filterPacket->CompletionContext = Packet->Completion.Send.SendCompletionContext;
1122
1123         // Use ours
1124         Packet->Completion.Send.OnSendCompletion = RndisFilterOnSendCompletion;
1125         Packet->Completion.Send.SendCompletionContext = filterPacket;
1126
1127         ret = gRndisFilter.InnerDriver.OnSend(Device, Packet);
1128         if (ret != 0)
1129         {
1130                 // Reset the completion to originals to allow retries from above
1131                 Packet->Completion.Send.OnSendCompletion = filterPacket->OnCompletion;
1132                 Packet->Completion.Send.SendCompletionContext = filterPacket->CompletionContext;
1133         }
1134
1135         DPRINT_EXIT(NETVSC);
1136
1137         return ret;
1138 }
1139
1140 static void
1141 RndisFilterOnSendCompletion(
1142    void *Context)
1143 {
1144         RNDIS_FILTER_PACKET *filterPacket = (RNDIS_FILTER_PACKET *)Context;
1145
1146         DPRINT_ENTER(NETVSC);
1147
1148         // Pass it back to the original handler
1149         filterPacket->OnCompletion(filterPacket->CompletionContext);
1150
1151         DPRINT_EXIT(NETVSC);
1152 }
1153
1154
1155 static void
1156 RndisFilterOnSendRequestCompletion(
1157    void *Context
1158    )
1159 {
1160         DPRINT_ENTER(NETVSC);
1161
1162         // Noop
1163         DPRINT_EXIT(NETVSC);
1164 }