diff options
author | Simone Bordet | 2013-02-27 16:30:22 +0000 |
---|---|---|
committer | Simone Bordet | 2013-02-27 16:30:22 +0000 |
commit | f28bd1d010902717ec1c56f27c5e3529021ed02d (patch) | |
tree | db20bdaee8eb140986a82242907ee0795069a4a8 /jetty-servlets | |
parent | 1627b6eaf74b156f224bedd9a7741595667e4b45 (diff) | |
parent | 90bab0eb6624619b588d0be2c2d574ab6ce9641f (diff) | |
download | org.eclipse.jetty.project-f28bd1d010902717ec1c56f27c5e3529021ed02d.tar.gz org.eclipse.jetty.project-f28bd1d010902717ec1c56f27c5e3529021ed02d.tar.xz org.eclipse.jetty.project-f28bd1d010902717ec1c56f27c5e3529021ed02d.zip |
Merged branch jetty-7.
Diffstat (limited to 'jetty-servlets')
5 files changed, 531 insertions, 390 deletions
diff --git a/jetty-servlets/pom.xml b/jetty-servlets/pom.xml index ff74704917..96d92402d7 100644 --- a/jetty-servlets/pom.xml +++ b/jetty-servlets/pom.xml @@ -91,6 +91,12 @@ </dependency> <dependency> <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-jmx</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> <artifactId>test-jetty-servlet</artifactId> <version>${project.version}</version> <scope>test</scope> diff --git a/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java b/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java index be3517d6ef..027eb6dcfa 100644 --- a/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java +++ b/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java @@ -19,16 +19,17 @@ package org.eclipse.jetty.servlets; import java.io.IOException; -import java.io.Serializable; -import java.util.HashSet; -import java.util.Map; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; import java.util.Queue; -import java.util.StringTokenizer; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; - +import java.util.regex.Matcher; +import java.util.regex.Pattern; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -54,9 +55,9 @@ import org.eclipse.jetty.util.thread.Timeout; /** * Denial of Service filter - * + * <p/> * <p> - * This filter is based on the {@link QoSFilter}. it is useful for limiting + * This filter is useful for limiting * exposure to abuse from request flooding, whether malicious, or as a result of * a misconfigured client. * <p> @@ -73,111 +74,109 @@ import org.eclipse.jetty.util.thread.Timeout; * implemented, in order to uniquely identify authenticated users. * <p> * The following init parameters control the behavior of the filter:<dl> - * + * <p/> * <dt>maxRequestsPerSec</dt> - * <dd>the maximum number of requests from a connection per - * second. Requests in excess of this are first delayed, - * then throttled.</dd> - * + * <dd>the maximum number of requests from a connection per + * second. Requests in excess of this are first delayed, + * then throttled.</dd> + * <p/> * <dt>delayMs</dt> - * <dd>is the delay given to all requests over the rate limit, - * before they are considered at all. -1 means just reject request, - * 0 means no delay, otherwise it is the delay.</dd> - * + * <dd>is the delay given to all requests over the rate limit, + * before they are considered at all. -1 means just reject request, + * 0 means no delay, otherwise it is the delay.</dd> + * <p/> * <dt>maxWaitMs</dt> - * <dd>how long to blocking wait for the throttle semaphore.</dd> - * + * <dd>how long to blocking wait for the throttle semaphore.</dd> + * <p/> * <dt>throttledRequests</dt> - * <dd>is the number of requests over the rate limit able to be - * considered at once.</dd> - * + * <dd>is the number of requests over the rate limit able to be + * considered at once.</dd> + * <p/> * <dt>throttleMs</dt> - * <dd>how long to async wait for semaphore.</dd> - * + * <dd>how long to async wait for semaphore.</dd> + * <p/> * <dt>maxRequestMs</dt> - * <dd>how long to allow this request to run.</dd> - * + * <dd>how long to allow this request to run.</dd> + * <p/> * <dt>maxIdleTrackerMs</dt> - * <dd>how long to keep track of request rates for a connection, - * before deciding that the user has gone away, and discarding it</dd> - * + * <dd>how long to keep track of request rates for a connection, + * before deciding that the user has gone away, and discarding it</dd> + * <p/> * <dt>insertHeaders</dt> - * <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd> - * + * <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd> + * <p/> * <dt>trackSessions</dt> - * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd> - * + * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd> + * <p/> * <dt>remotePort</dt> - * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd> - * + * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd> + * <p/> * <dt>ipWhitelist</dt> - * <dd>a comma-separated list of IP addresses that will not be rate limited</dd> - * + * <dd>a comma-separated list of IP addresses that will not be rate limited</dd> + * <p/> * <dt>managedAttr</dt> - * <dd>if set to true, then this servlet is set as a {@link ServletContext} attribute with the + * <dd>if set to true, then this servlet is set as a {@link ServletContext} attribute with the * filter name as the attribute name. This allows context external mechanism (eg JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to * manage the configuration of the filter.</dd> * </dl> * </p> */ - public class DoSFilter implements Filter { private static final Logger LOG = Log.getLogger(DoSFilter.class); - final static String __TRACKER = "DoSFilter.Tracker"; - final static String __THROTTLED = "DoSFilter.Throttled"; - - final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25; - final static int __DEFAULT_DELAY_MS = 100; - final static int __DEFAULT_THROTTLE = 5; - final static int __DEFAULT_WAIT_MS=50; - final static long __DEFAULT_THROTTLE_MS = 30000L; - final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L; - final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L; - - final static String MANAGED_ATTR_INIT_PARAM="managedAttr"; - final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec"; - final static String DELAY_MS_INIT_PARAM = "delayMs"; - final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests"; - final static String MAX_WAIT_INIT_PARAM="maxWaitMs"; - final static String THROTTLE_MS_INIT_PARAM = "throttleMs"; - final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs"; - final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs"; - final static String INSERT_HEADERS_INIT_PARAM="insertHeaders"; - final static String TRACK_SESSIONS_INIT_PARAM="trackSessions"; - final static String REMOTE_PORT_INIT_PARAM="remotePort"; - final static String IP_WHITELIST_INIT_PARAM="ipWhitelist"; - - final static int USER_AUTH = 2; - final static int USER_SESSION = 2; - final static int USER_IP = 1; - final static int USER_UNKNOWN = 0; - - ServletContext _context; - - protected String _name; - protected long _delayMs; - protected long _throttleMs; - protected long _maxWaitMs; - protected long _maxRequestMs; - protected long _maxIdleTrackerMs; - protected boolean _insertHeaders; - protected boolean _trackSessions; - protected boolean _remotePort; - protected int _throttledRequests; - protected Semaphore _passes; - protected Queue<Continuation>[] _queue; - protected ContinuationListener[] _listener; - - protected int _maxRequestsPerSec; - protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>(); - protected String _whitelistStr; - private final HashSet<String> _whitelist = new HashSet<String>(); - + private static final Pattern IP_PATTERN = Pattern.compile("(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})"); + private static final Pattern CIDR_PATTERN = Pattern.compile(IP_PATTERN + "/(\\d{1,2})"); + + private static final String __TRACKER = "DoSFilter.Tracker"; + private static final String __THROTTLED = "DoSFilter.Throttled"; + + private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25; + private static final int __DEFAULT_DELAY_MS = 100; + private static final int __DEFAULT_THROTTLE = 5; + private static final int __DEFAULT_MAX_WAIT_MS = 50; + private static final long __DEFAULT_THROTTLE_MS = 30000L; + private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L; + private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L; + + static final String MANAGED_ATTR_INIT_PARAM = "managedAttr"; + static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec"; + static final String DELAY_MS_INIT_PARAM = "delayMs"; + static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests"; + static final String MAX_WAIT_INIT_PARAM = "maxWaitMs"; + static final String THROTTLE_MS_INIT_PARAM = "throttleMs"; + static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs"; + static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs"; + static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders"; + static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions"; + static final String REMOTE_PORT_INIT_PARAM = "remotePort"; + static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist"; + static final String ENABLED_INIT_PARAM = "enabled"; + + private static final int USER_AUTH = 2; + private static final int USER_SESSION = 2; + private static final int USER_IP = 1; + private static final int USER_UNKNOWN = 0; + + private ServletContext _context; + private volatile long _delayMs; + private volatile long _throttleMs; + private volatile long _maxWaitMs; + private volatile long _maxRequestMs; + private volatile long _maxIdleTrackerMs; + private volatile boolean _insertHeaders; + private volatile boolean _trackSessions; + private volatile boolean _remotePort; + private volatile boolean _enabled; + private Semaphore _passes; + private volatile int _throttledRequests; + private volatile int _maxRequestsPerSec; + private Queue<Continuation>[] _queue; + private ContinuationListener[] _listeners; + private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>(); + private final List<String> _whitelist = new CopyOnWriteArrayList<String>(); private final Timeout _requestTimeoutQ = new Timeout(); private final Timeout _trackerTimeoutQ = new Timeout(); - private Thread _timerThread; private volatile boolean _running; @@ -186,13 +185,13 @@ public class DoSFilter implements Filter _context = filterConfig.getServletContext(); _queue = new Queue[getMaxPriority() + 1]; - _listener = new ContinuationListener[getMaxPriority() + 1]; + _listeners = new ContinuationListener[getMaxPriority() + 1]; for (int p = 0; p < _queue.length; p++) { _queue[p] = new ConcurrentLinkedQueue<Continuation>(); - final int priority=p; - _listener[p] = new ContinuationListener() + final int priority = p; + _listeners[p] = new ContinuationListener() { public void onComplete(Continuation continuation) { @@ -207,55 +206,65 @@ public class DoSFilter implements Filter _rateTrackers.clear(); - int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC; - if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null) - baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM)); - _maxRequestsPerSec = baseRateLimit; + int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC; + String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM); + if (parameter != null) + maxRequests = Integer.parseInt(parameter); + setMaxRequestsPerSec(maxRequests); long delay = __DEFAULT_DELAY_MS; - if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null) - delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM)); - _delayMs = delay; + parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM); + if (parameter != null) + delay = Long.parseLong(parameter); + setDelayMs(delay); int throttledRequests = __DEFAULT_THROTTLE; - if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null) - throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM)); - _passes = new Semaphore(throttledRequests,true); - _throttledRequests = throttledRequests; - - long wait = __DEFAULT_WAIT_MS; - if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null) - wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM)); - _maxWaitMs = wait; - - long suspend = __DEFAULT_THROTTLE_MS; - if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null) - suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM)); - _throttleMs = suspend; + parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM); + if (parameter != null) + throttledRequests = Integer.parseInt(parameter); + setThrottledRequests(throttledRequests); + + long maxWait = __DEFAULT_MAX_WAIT_MS; + parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM); + if (parameter != null) + maxWait = Long.parseLong(parameter); + setMaxWaitMs(maxWait); + + long throttle = __DEFAULT_THROTTLE_MS; + parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM); + if (parameter != null) + throttle = Long.parseLong(parameter); + setThrottleMs(throttle); long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM; - if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null ) - maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM)); - _maxRequestMs = maxRequestMs; + parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM); + if (parameter != null) + maxRequestMs = Long.parseLong(parameter); + setMaxRequestMs(maxRequestMs); long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM; - if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null ) - maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM)); - _maxIdleTrackerMs = maxIdleTrackerMs; + parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM); + if (parameter != null) + maxIdleTrackerMs = Long.parseLong(parameter); + setMaxIdleTrackerMs(maxIdleTrackerMs); + + String whiteList = ""; + parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM); + if (parameter != null) + whiteList = parameter; + setWhitelist(whiteList); - _whitelistStr = ""; - if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null ) - _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM); - initWhitelist(); + parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM); + setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter)); - String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM); - _insertHeaders = tmp==null || Boolean.parseBoolean(tmp); + parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM); + setTrackSessions(parameter == null || Boolean.parseBoolean(parameter)); - tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM); - _trackSessions = tmp==null || Boolean.parseBoolean(tmp); + parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM); + setRemotePort(parameter != null && Boolean.parseBoolean(parameter)); - tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM); - _remotePort = tmp!=null&& Boolean.parseBoolean(tmp); + parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM); + setEnabled(parameter == null || Boolean.parseBoolean(parameter)); _requestTimeoutQ.setNow(); _requestTimeoutQ.setDuration(_maxRequestMs); @@ -263,7 +272,7 @@ public class DoSFilter implements Filter _trackerTimeoutQ.setNow(); _trackerTimeoutQ.setDuration(_maxIdleTrackerMs); - _running=true; + _running = true; _timerThread = (new Thread() { public void run() @@ -272,17 +281,10 @@ public class DoSFilter implements Filter { while (_running) { - long now; - synchronized (_requestTimeoutQ) - { - now = _requestTimeoutQ.setNow(); - _requestTimeoutQ.tick(); - } - synchronized (_trackerTimeoutQ) - { - _trackerTimeoutQ.setNow(now); - _trackerTimeoutQ.tick(); - } + long now = _requestTimeoutQ.setNow(); + _requestTimeoutQ.tick(); + _trackerTimeoutQ.setNow(now); + _trackerTimeoutQ.tick(); try { Thread.sleep(100); @@ -295,28 +297,35 @@ public class DoSFilter implements Filter } finally { - LOG.info("DoSFilter timer exited"); + LOG.debug("DoSFilter timer exited"); } } }); _timerThread.start(); - if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM))) - _context.setAttribute(filterConfig.getFilterName(),this); + if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM))) + _context.setAttribute(filterConfig.getFilterName(), this); } + public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException + { + doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain); + } - public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException + protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException { - final HttpServletRequest srequest = (HttpServletRequest)request; - final HttpServletResponse sresponse = (HttpServletResponse)response; + if (!isEnabled()) + { + filterChain.doFilter(request, response); + return; + } - final long now=_requestTimeoutQ.getNow(); + final long now = _requestTimeoutQ.getNow(); // Look for the rate tracker for this request RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER); - if (tracker==null) + if (tracker == null) { // This is the first time we have seen this request. @@ -329,62 +338,53 @@ public class DoSFilter implements Filter // pass it through if we are not currently over the rate limit if (!overRateLimit) { - doFilterChain(filterchain,srequest,sresponse); + doFilterChain(filterChain, request, response); return; } // We are over the limit. - LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal()); + LOG.warn("DOS ALERT: ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal()); // So either reject it, delay it or throttle it - switch((int)_delayMs) + long delayMs = getDelayMs(); + boolean insertHeaders = isInsertHeaders(); + switch ((int)delayMs) { case -1: { // Reject this request - if (_insertHeaders) - ((HttpServletResponse)response).addHeader("DoSFilter","unavailable"); - - ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + if (insertHeaders) + response.addHeader("DoSFilter", "unavailable"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); return; } case 0: { - // fall through to throttle code - request.setAttribute(__TRACKER,tracker); + // fall through to throttle code + request.setAttribute(__TRACKER, tracker); break; } default: - { + { // insert a delay before throttling the request - if (_insertHeaders) - ((HttpServletResponse)response).addHeader("DoSFilter","delayed"); + if (insertHeaders) + response.addHeader("DoSFilter", "delayed"); Continuation continuation = ContinuationSupport.getContinuation(request); - request.setAttribute(__TRACKER,tracker); - if (_delayMs > 0) - continuation.setTimeout(_delayMs); - continuation.addContinuationListener(new ContinuationListener() - { - - public void onComplete(Continuation continuation) - { - } - - public void onTimeout(Continuation continuation) - { - } - }); + request.setAttribute(__TRACKER, tracker); + if (delayMs > 0) + continuation.setTimeout(delayMs); continuation.suspend(); return; } } } + // Throttle the request boolean accepted = false; try { // check if we can afford to accept another request at this time - accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS); + accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS); if (!accepted) { @@ -392,22 +392,23 @@ public class DoSFilter implements Filter final Continuation continuation = ContinuationSupport.getContinuation(request); Boolean throttled = (Boolean)request.getAttribute(__THROTTLED); - if (throttled!=Boolean.TRUE && _throttleMs>0) + long throttleMs = getThrottleMs(); + if (throttled != Boolean.TRUE && throttleMs > 0) { - int priority = getPriority(request,tracker); - request.setAttribute(__THROTTLED,Boolean.TRUE); - if (_insertHeaders) - ((HttpServletResponse)response).addHeader("DoSFilter","throttled"); - if (_throttleMs > 0) - continuation.setTimeout(_throttleMs); + int priority = getPriority(request, tracker); + request.setAttribute(__THROTTLED, Boolean.TRUE); + if (isInsertHeaders()) + response.addHeader("DoSFilter", "throttled"); + if (throttleMs > 0) + continuation.setTimeout(throttleMs); continuation.suspend(); - continuation.addContinuationListener(_listener[priority]); + continuation.addContinuationListener(_listeners[priority]); _queue[priority].add(continuation); return; } // else were we resumed? - else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE) + else if (request.getAttribute("javax.servlet.resumed") == Boolean.TRUE) { // we were resumed and somebody stole our pass, so we wait for the next one. _passes.acquire(); @@ -418,30 +419,26 @@ public class DoSFilter implements Filter // if we were accepted (either immediately or after throttle) if (accepted) // call the chain - doFilterChain(filterchain,srequest,sresponse); + doFilterChain(filterChain, request, response); else { // fail the request - if (_insertHeaders) - ((HttpServletResponse)response).addHeader("DoSFilter","unavailable"); - ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + if (isInsertHeaders()) + response.addHeader("DoSFilter", "unavailable"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); } } catch (InterruptedException e) { - _context.log("DoS",e); - ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); - } - catch (Exception e) - { - e.printStackTrace(); + _context.log("DoS", e); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); } finally { if (accepted) { // wake up the next highest priority request. - for (int p = _queue.length; p-- > 0;) + for (int p = _queue.length; p-- > 0; ) { Continuation continuation = _queue[p].poll(); if (continuation != null && continuation.isSuspended()) @@ -455,17 +452,9 @@ public class DoSFilter implements Filter } } - /** - * @param chain - * @param request - * @param response - * @throws IOException - * @throws ServletException - */ - protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) - throws IOException, ServletException + protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException { - final Thread thread=Thread.currentThread(); + final Thread thread = Thread.currentThread(); final Timeout.Task requestTimeout = new Timeout.Task() { @@ -477,32 +466,27 @@ public class DoSFilter implements Filter try { - synchronized (_requestTimeoutQ) - { - _requestTimeoutQ.schedule(requestTimeout); - } - chain.doFilter(request,response); + _requestTimeoutQ.schedule(requestTimeout); + chain.doFilter(request, response); } finally { - synchronized (_requestTimeoutQ) - { - requestTimeout.cancel(); - } + requestTimeout.cancel(); } } /** * Takes drastic measures to return this response and stop this thread. * Due to the way the connection is interrupted, may return mixed up headers. - * @param request current request + * + * @param request current request * @param response current response, which must be stopped - * @param thread the handling thread + * @param thread the handling thread */ protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread) { // take drastic measures to return this response and stop this thread. - if( !response.isCommitted() ) + if (!response.isCommitted()) { response.setHeader("Connection", "close"); } @@ -529,15 +513,15 @@ public class DoSFilter implements Filter /** * Get priority for this request, based on user type * - * @param request - * @param tracker - * @return priority + * @param request the current request + * @param tracker the rate tracker for this request + * @return the priority for this request */ - protected int getPriority(ServletRequest request, RateTracker tracker) + protected int getPriority(HttpServletRequest request, RateTracker tracker) { - if (extractUserId(request)!=null) + if (extractUserId(request) != null) return USER_AUTH; - if (tracker!=null) + if (tracker != null) return tracker.getType(); return USER_UNKNOWN; } @@ -555,21 +539,20 @@ public class DoSFilter implements Filter * track of this connection's request rate. If this is not the first request * from this connection, return the existing object with the stored stats. * If it is the first request, then create a new request tracker. - * + * <p/> * Assumes that each connection has an identifying characteristic, and goes * through them in order, taking the first that matches: user id (logged * in), session id, client IP address. Unidentifiable connections are lumped * into one. - * + * <p/> * When a session expires, its rate tracker is automatically deleted. * - * @param request + * @param request the current request * @return the request rate tracker for the current connection */ public RateTracker getRateTracker(ServletRequest request) { - HttpServletRequest srequest = (HttpServletRequest)request; - HttpSession session=srequest.getSession(false); + HttpSession session = ((HttpServletRequest)request).getSession(false); String loadId = extractUserId(request); final int type; @@ -579,64 +562,95 @@ public class DoSFilter implements Filter } else { - if (_trackSessions && session!=null && !session.isNew()) + if (_trackSessions && session != null && !session.isNew()) { - loadId=session.getId(); + loadId = session.getId(); type = USER_SESSION; } else { - loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr(); + loadId = _remotePort ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr(); type = USER_IP; } } - RateTracker tracker=_rateTrackers.get(loadId); + RateTracker tracker = _rateTrackers.get(loadId); - if (tracker==null) + if (tracker == null) { - RateTracker t; - if (_whitelist.contains(request.getRemoteAddr())) - { - t = new FixedRateTracker(loadId,type,_maxRequestsPerSec); - } - else - { - t = new RateTracker(loadId,type,_maxRequestsPerSec); - } - - tracker=_rateTrackers.putIfAbsent(loadId,t); - if (tracker==null) - tracker=t; + boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr()); + tracker = allowed ? new FixedRateTracker(loadId, type, _maxRequestsPerSec) + : new RateTracker(loadId, type, _maxRequestsPerSec); + RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker); + if (existing != null) + tracker = existing; if (type == USER_IP) { // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ - synchronized (_trackerTimeoutQ) - { - _trackerTimeoutQ.schedule(tracker); - } + _trackerTimeoutQ.schedule(tracker); } - else if (session!=null) + else if (session != null) + { // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener - session.setAttribute(__TRACKER,tracker); + session.setAttribute(__TRACKER, tracker); + } } return tracker; } - public void destroy() + protected boolean checkWhitelist(List<String> whitelist, String candidate) { - _running=false; - _timerThread.interrupt(); - synchronized (_requestTimeoutQ) + for (String address : whitelist) { - _requestTimeoutQ.cancelAll(); + if (address.contains("/")) + { + if (subnetMatch(address, candidate)) + return true; + } + else + { + if (address.equals(candidate)) + return true; + } } - synchronized (_trackerTimeoutQ) + return false; + } + + protected boolean subnetMatch(String subnetAddress, String candidate) + { + Matcher matcher = CIDR_PATTERN.matcher(subnetAddress); + int subnet = intFromAddress(matcher); + int prefix = Integer.parseInt(matcher.group(5)); + // Sets the most significant prefix bits to 1 + // If prefix == 8 => 11111111_00000000_00000000_00000000 + int mask = ~((1 << (32 - prefix)) - 1); + int ip = intFromAddress(IP_PATTERN.matcher(candidate)); + return (ip & mask) == (subnet & mask); + } + + private int intFromAddress(Matcher matcher) + { + int result = 0; + if (matcher.matches()) { - _trackerTimeoutQ.cancelAll(); + for (int i = 0; i < 4; ++i) + { + int b = Integer.parseInt(matcher.group(i + 1)); + result |= b << 8 * (3 - i); + } + return result; } + throw new IllegalStateException(); + } + + public void destroy() + { + _running = false; + _timerThread.interrupt(); + _requestTimeoutQ.cancelAll(); + _trackerTimeoutQ.cancelAll(); _rateTrackers.clear(); _whitelist.clear(); } @@ -645,7 +659,7 @@ public class DoSFilter implements Filter * Returns the user id, used to track this connection. * This SHOULD be overridden by subclasses. * - * @param request + * @param request the current request * @return a unique user id, if logged in; otherwise null. */ protected String extractUserId(ServletRequest request) @@ -653,26 +667,11 @@ public class DoSFilter implements Filter return null; } - /* ------------------------------------------------------------ */ /** - * Initialize the IP address whitelist - */ - protected void initWhitelist() - { - _whitelist.clear(); - StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ","); - while (tokenizer.hasMoreTokens()) - _whitelist.add(tokenizer.nextToken().trim()); - - LOG.info("Whitelisted IP addresses: {}", _whitelist.toString()); - } - - /* ------------------------------------------------------------ */ - /** * Get maximum number of requests from a connection per * second. Requests in excess of this are first delayed, * then throttled. - * + * * @return maximum number of requests */ public int getMaxRequestsPerSec() @@ -680,46 +679,42 @@ public class DoSFilter implements Filter return _maxRequestsPerSec; } - /* ------------------------------------------------------------ */ - /** + /** * Get maximum number of requests from a connection per * second. Requests in excess of this are first delayed, * then throttled. - * + * * @param value maximum number of requests */ public void setMaxRequestsPerSec(int value) { _maxRequestsPerSec = value; } - - /* ------------------------------------------------------------ */ + /** - * Get delay (in milliseconds) that is applied to all requests - * over the rate limit, before they are considered at all. + * Get delay (in milliseconds) that is applied to all requests + * over the rate limit, before they are considered at all. */ public long getDelayMs() { return _delayMs; } - /* ------------------------------------------------------------ */ /** - * Set delay (in milliseconds) that is applied to all requests - * over the rate limit, before they are considered at all. - * + * Set delay (in milliseconds) that is applied to all requests + * over the rate limit, before they are considered at all. + * * @param value delay (in milliseconds), 0 - no delay, -1 - reject request */ public void setDelayMs(long value) { _delayMs = value; } - - /* ------------------------------------------------------------ */ + /** * Get maximum amount of time (in milliseconds) the filter will * blocking wait for the throttle semaphore. - * + * * @return maximum wait time */ public long getMaxWaitMs() @@ -727,11 +722,10 @@ public class DoSFilter implements Filter return _maxWaitMs; } - /* ------------------------------------------------------------ */ - /** + /** * Set maximum amount of time (in milliseconds) the filter will * blocking wait for the throttle semaphore. - * + * * @param value maximum wait time */ public void setMaxWaitMs(long value) @@ -739,11 +733,10 @@ public class DoSFilter implements Filter _maxWaitMs = value; } - /* ------------------------------------------------------------ */ /** * Get number of requests over the rate limit able to be * considered at once. - * + * * @return number of requests */ public int getThrottledRequests() @@ -751,23 +744,22 @@ public class DoSFilter implements Filter return _throttledRequests; } - /* ------------------------------------------------------------ */ /** * Set number of requests over the rate limit able to be * considered at once. - * + * * @param value number of requests */ public void setThrottledRequests(int value) { - _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true); + int permits = _passes == null ? 0 : _passes.availablePermits(); + _passes = new Semaphore((value - _throttledRequests + permits), true); _throttledRequests = value; } - /* ------------------------------------------------------------ */ /** * Get amount of time (in milliseconds) to async wait for semaphore. - * + * * @return wait time */ public long getThrottleMs() @@ -775,10 +767,9 @@ public class DoSFilter implements Filter return _throttleMs; } - /* ------------------------------------------------------------ */ /** * Set amount of time (in milliseconds) to async wait for semaphore. - * + * * @param value wait time */ public void setThrottleMs(long value) @@ -786,11 +777,10 @@ public class DoSFilter implements Filter _throttleMs = value; } - /* ------------------------------------------------------------ */ - /** - * Get maximum amount of time (in milliseconds) to allow + /** + * Get maximum amount of time (in milliseconds) to allow * the request to process. - * + * * @return maximum processing time */ public long getMaxRequestMs() @@ -798,11 +788,10 @@ public class DoSFilter implements Filter return _maxRequestMs; } - /* ------------------------------------------------------------ */ - /** - * Set maximum amount of time (in milliseconds) to allow + /** + * Set maximum amount of time (in milliseconds) to allow * the request to process. - * + * * @param value maximum processing time */ public void setMaxRequestMs(long value) @@ -810,12 +799,11 @@ public class DoSFilter implements Filter _maxRequestMs = value; } - /* ------------------------------------------------------------ */ /** * Get maximum amount of time (in milliseconds) to keep track - * of request rates for a connection, before deciding that + * of request rates for a connection, before deciding that * the user has gone away, and discarding it. - * + * * @return maximum tracking time */ public long getMaxIdleTrackerMs() @@ -823,12 +811,11 @@ public class DoSFilter implements Filter return _maxIdleTrackerMs; } - /* ------------------------------------------------------------ */ - /** + /** * Set maximum amount of time (in milliseconds) to keep track - * of request rates for a connection, before deciding that + * of request rates for a connection, before deciding that * the user has gone away, and discarding it. - * + * * @param value maximum tracking time */ public void setMaxIdleTrackerMs(long value) @@ -836,10 +823,9 @@ public class DoSFilter implements Filter _maxIdleTrackerMs = value; } - /* ------------------------------------------------------------ */ - /** + /** * Check flag to insert the DoSFilter headers into the response. - * + * * @return value of the flag */ public boolean isInsertHeaders() @@ -847,10 +833,9 @@ public class DoSFilter implements Filter return _insertHeaders; } - /* ------------------------------------------------------------ */ - /** + /** * Set flag to insert the DoSFilter headers into the response. - * + * * @param value value of the flag */ public void setInsertHeaders(boolean value) @@ -858,10 +843,9 @@ public class DoSFilter implements Filter _insertHeaders = value; } - /* ------------------------------------------------------------ */ - /** + /** * Get flag to have usage rate tracked by session if a session exists. - * + * * @return value of the flag */ public boolean isTrackSessions() @@ -869,9 +853,9 @@ public class DoSFilter implements Filter return _trackSessions; } - /* ------------------------------------------------------------ */ - /** + /** * Set flag to have usage rate tracked by session if a session exists. + * * @param value value of the flag */ public void setTrackSessions(boolean value) @@ -879,11 +863,10 @@ public class DoSFilter implements Filter _trackSessions = value; } - /* ------------------------------------------------------------ */ - /** + /** * Get flag to have usage rate tracked by IP+port (effectively connection) * if session tracking is not used. - * + * * @return value of the flag */ public boolean isRemotePort() @@ -891,12 +874,10 @@ public class DoSFilter implements Filter return _remotePort; } - - /* ------------------------------------------------------------ */ - /** + /** * Set flag to have usage rate tracked by IP+port (effectively connection) * if session tracking is not used. - * + * * @param value value of the flag */ public void setRemotePort(boolean value) @@ -904,28 +885,81 @@ public class DoSFilter implements Filter _remotePort = value; } - /* ------------------------------------------------------------ */ + /** + * @return whether this filter is enabled + */ + public boolean isEnabled() + { + return _enabled; + } + + /** + * @param enabled whether this filter is enabled + */ + public void setEnabled(boolean enabled) + { + _enabled = enabled; + } + /** * Get a list of IP addresses that will not be rate limited. - * + * * @return comma-separated whitelist */ public String getWhitelist() { - return _whitelistStr; + StringBuilder result = new StringBuilder(); + for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();) + { + String address = iterator.next(); + result.append(address); + if (iterator.hasNext()) + result.append(","); + } + return result.toString(); } - - /* ------------------------------------------------------------ */ /** * Set a list of IP addresses that will not be rate limited. - * + * * @param value comma-separated whitelist */ public void setWhitelist(String value) { - _whitelistStr = value; - initWhitelist(); + List<String> result = new ArrayList<String>(); + for (String address : value.split(",")) + addWhitelistAddress(result, address); + _whitelist.clear(); + _whitelist.addAll(result); + LOG.debug("Whitelisted IP addresses: {}", result); + } + + public void clearWhitelist() + { + _whitelist.clear(); + } + + public boolean addWhitelistAddress(String address) + { + return addWhitelistAddress(_whitelist, address); + } + + private boolean addWhitelistAddress(List<String> list, String address) + { + address = address.trim(); + if (address.length() > 0) + { + if (CIDR_PATTERN.matcher(address).matches() || IP_PATTERN.matcher(address).matches()) + return list.add(address); + else + LOG.warn("Ignoring malformed whitelist IP address {}", address); + } + return false; + } + + public boolean removeWhitelistAddress(String address) + { + return _whitelist.remove(address); } /** @@ -938,14 +972,13 @@ public class DoSFilter implements Filter transient protected final int _type; transient protected final long[] _timestamps; transient protected int _next; - - - public RateTracker(String id, int type,int maxRequestsPerSecond) + + public RateTracker(String id, int type, int maxRequestsPerSecond) { _id = id; _type = type; - _timestamps=new long[maxRequestsPerSecond]; - _next=0; + _timestamps = new long[maxRequestsPerSecond]; + _next = 0; } /** @@ -956,16 +989,14 @@ public class DoSFilter implements Filter final long last; synchronized (this) { - last=_timestamps[_next]; - _timestamps[_next]=now; - _next= (_next+1)%_timestamps.length; + last = _timestamps[_next]; + _timestamps[_next] = now; + _next = (_next + 1) % _timestamps.length; } - boolean exceeded=last!=0 && (now-last)<1000L; - return exceeded; + return last != 0 && (now - last) < 1000L; } - public String getId() { return _id; @@ -976,67 +1007,59 @@ public class DoSFilter implements Filter return _type; } - public void valueBound(HttpSessionBindingEvent event) - { + { if (LOG.isDebugEnabled()) - LOG.debug("Value bound:"+_id); + LOG.debug("Value bound: {}", getId()); } public void valueUnbound(HttpSessionBindingEvent event) { //take the tracker out of the list of trackers - if (_rateTrackers != null) - _rateTrackers.remove(_id); - if (LOG.isDebugEnabled()) LOG.debug("Tracker removed: "+_id); + _rateTrackers.remove(_id); + if (LOG.isDebugEnabled()) + LOG.debug("Tracker removed: {}", getId()); } public void sessionWillPassivate(HttpSessionEvent se) { //take the tracker of the list of trackers (if its still there) //and ensure that we take ourselves out of the session so we are not saved - if (_rateTrackers != null) - _rateTrackers.remove(_id); + _rateTrackers.remove(_id); se.getSession().removeAttribute(__TRACKER); - if (LOG.isDebugEnabled()) LOG.debug("Value removed: "+_id); + if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId()); } public void sessionDidActivate(HttpSessionEvent se) { LOG.warn("Unexpected session activation"); } - - + public void expired() { - if (_rateTrackers != null && _trackerTimeoutQ != null) - { - long now = _trackerTimeoutQ.getNow(); - int latestIndex = _next == 0 ? (_timestamps.length-1) : (_next - 1 ); - long last=_timestamps[latestIndex]; - boolean hasRecentRequest = last != 0 && (now-last)<1000L; - - if (hasRecentRequest) - reschedule(); - else - _rateTrackers.remove(_id); - } + long now = _trackerTimeoutQ.getNow(); + int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1); + long last = _timestamps[latestIndex]; + boolean hasRecentRequest = last != 0 && (now - last) < 1000L; + + if (hasRecentRequest) + reschedule(); + else + _rateTrackers.remove(_id); } @Override public String toString() { - return "RateTracker/"+_id+"/"+_type; + return "RateTracker/" + _id + "/" + _type; } - - } class FixedRateTracker extends RateTracker { public FixedRateTracker(String id, int type, int numRecentRequestsTracked) { - super(id,type,numRecentRequestsTracked); + super(id, type, numRecentRequestsTracked); } @Override @@ -1047,8 +1070,8 @@ public class DoSFilter implements Filter // and whether it should be expired synchronized (this) { - _timestamps[_next]=now; - _next= (_next+1)%_timestamps.length; + _timestamps[_next] = now; + _next = (_next + 1) % _timestamps.length; } return false; @@ -1057,7 +1080,7 @@ public class DoSFilter implements Filter @Override public String toString() { - return "Fixed"+super.toString(); + return "Fixed" + super.toString(); } } } diff --git a/jetty-servlets/src/main/resources/org/eclipse/jetty/servlets/jmx/DoSFilter-mbean.properties b/jetty-servlets/src/main/resources/org/eclipse/jetty/servlets/jmx/DoSFilter-mbean.properties index 6a1f31aa49..9523d23a3a 100644 --- a/jetty-servlets/src/main/resources/org/eclipse/jetty/servlets/jmx/DoSFilter-mbean.properties +++ b/jetty-servlets/src/main/resources/org/eclipse/jetty/servlets/jmx/DoSFilter-mbean.properties @@ -1,6 +1,6 @@ DoSFilter: Limit exposure to abuse from request flooding, whether malicious, or as a result of a misconfigured client. maxRequestsPerSec: maximum number of requests from a connection per second. Requests in excess of this are first delayed, then throttled. -delayMs: delay (in milliseconds) that is applied to all requests over the rate limit, before they are considered at all, 0 - no delay, -1 - reject request. +delayMs: delay (in milliseconds) that is applied to all requests over the rate limit, before they are considered at all, 0 - no delay, -1 - reject request. maxWaitMs: maximum amount of time (in milliseconds) the filter will blocking wait for the throttle semaphore. throttledRequests: number of requests over the rate limit able to be considered at once. throttleMs: amount of time (in milliseconds) to async wait for semaphore. @@ -9,4 +9,10 @@ maxIdleTrackerMs: maximum amount of time (in milliseconds) to keep track of requ insertHeaders: insert the DoSFilter headers into the response. trackSessions: usage rate is tracked by session if a session exists. remotePort: usage rate is tracked by IP+port (effectively connection) if session tracking is not used. -ipWhitelist: list of IP addresses that will not be rate limited.
\ No newline at end of file +enabled: whether this filter is enabled +whitelist: comma separated list of IP addresses that will not be rate limited. +clearWhitelist(): clears the list of IP addresses that will not be rate limited. +addWhitelistAddress(java.lang.String):ACTION: adds an IP address that will not be rate limited. +addWhitelistAddress(java.lang.String)[0]:address: the IP address that will not be rate limited. +removeWhitelistAddress(java.lang.String):ACTION: removes an IP address that will not be rate limited. +removeWhitelistAddress(java.lang.String)[0]:address: the IP address that will not be rate limited. diff --git a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterJMXTest.java b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterJMXTest.java new file mode 100644 index 0000000000..cd2ba3a857 --- /dev/null +++ b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterJMXTest.java @@ -0,0 +1,88 @@ +// +// ======================================================================== +// Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.eclipse.jetty.servlets; + +import java.lang.management.ManagementFactory; +import java.util.EnumSet; +import java.util.Set; +import javax.management.Attribute; +import javax.management.MBeanServer; +import javax.management.ObjectName; + +import org.eclipse.jetty.jmx.MBeanContainer; +import org.eclipse.jetty.server.Connector; +import org.eclipse.jetty.server.DispatcherType; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.nio.SelectChannelConnector; +import org.eclipse.jetty.servlet.FilterHolder; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.junit.Assert; +import org.junit.Test; + +public class DoSFilterJMXTest +{ + @Test + public void testDoSFilterJMX() throws Exception + { + Server server = new Server(); + Connector connector = new SelectChannelConnector(); + connector.setPort(0); + server.addConnector(connector); + + ServletContextHandler context = new ServletContextHandler(server, "/", ServletContextHandler.SESSIONS); + DoSFilter filter = new DoSFilter(); + FilterHolder holder = new FilterHolder(filter); + String name = "dos"; + holder.setName(name); + holder.setInitParameter(DoSFilter.MANAGED_ATTR_INIT_PARAM, "true"); + context.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST)); + context.setInitParameter(ServletContextHandler.MANAGED_ATTRIBUTES, name); + + MBeanServer mbeanServer = ManagementFactory.getPlatformMBeanServer(); + MBeanContainer mbeanContainer = new MBeanContainer(mbeanServer); + server.addBean(mbeanContainer); + server.getContainer().addEventListener(mbeanContainer); + + server.start(); + + String domain = DoSFilter.class.getPackage().getName(); + Set<ObjectName> mbeanNames = mbeanServer.queryNames(ObjectName.getInstance(domain + ":*"), null); + Assert.assertEquals(1, mbeanNames.size()); + ObjectName objectName = mbeanNames.iterator().next(); + + boolean value = (Boolean)mbeanServer.getAttribute(objectName, "enabled"); + mbeanServer.setAttribute(objectName, new Attribute("enabled", !value)); + Assert.assertEquals(!value, filter.isEnabled()); + + String whitelist = (String)mbeanServer.getAttribute(objectName, "whitelist"); + String address = "127.0.0.1"; + Assert.assertFalse(whitelist.contains(address)); + boolean result = (Boolean)mbeanServer.invoke(objectName, "addWhitelistAddress", new Object[]{address}, new String[]{String.class.getName()}); + Assert.assertTrue(result); + whitelist = (String)mbeanServer.getAttribute(objectName, "whitelist"); + Assert.assertTrue(whitelist.contains(address)); + + result = (Boolean)mbeanServer.invoke(objectName, "removeWhitelistAddress", new Object[]{address}, new String[]{String.class.getName()}); + Assert.assertTrue(result); + whitelist = (String)mbeanServer.getAttribute(objectName, "whitelist"); + Assert.assertFalse(whitelist.contains(address)); + + server.stop(); + } +} diff --git a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java index 2542d95b0c..a05e64204d 100644 --- a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java +++ b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java @@ -18,18 +18,21 @@ package org.eclipse.jetty.servlets; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - +import java.util.ArrayList; +import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.servlets.DoSFilter.RateTracker; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + public class DoSFilterTest extends AbstractDoSFilterTest { private static final Logger LOG = Log.getLogger(DoSFilterTest.class); @@ -70,6 +73,21 @@ public class DoSFilterTest extends AbstractDoSFilterTest assertFalse("Should not exceed as we sleep 300s for each hit and thus do less than 4 hits/s",exceeded); } + @Test + public void testWhitelist() throws Exception + { + DoSFilter filter = new DoSFilter(); + List<String> whitelist = new ArrayList<String>(); + whitelist.add("192.168.0.1"); + whitelist.add("10.0.0.0/8"); + Assert.assertTrue(filter.checkWhitelist(whitelist, "192.168.0.1")); + Assert.assertFalse(filter.checkWhitelist(whitelist, "192.168.0.2")); + Assert.assertFalse(filter.checkWhitelist(whitelist, "11.12.13.14")); + Assert.assertTrue(filter.checkWhitelist(whitelist, "10.11.12.13")); + Assert.assertTrue(filter.checkWhitelist(whitelist, "10.0.0.0")); + Assert.assertFalse(filter.checkWhitelist(whitelist, "0.0.0.0")); + } + private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws InterruptedException { boolean exceeded = false; |