Exemplo n.º 1
0
  /**
   * Builds the path, extension and name based lookup tables for the deployment's servlets and
   * filters. A chain is set up for every path + extension combination (i.e. with m path mappings
   * and n extension mappings there are n*m chains).
   *
   * <p>The method works in four steps:
   *
   * <ol>
   *   <li>collect every URL pattern used by a filter mapping (path patterns and {@code *.ext}
   *       patterns are tracked separately);
   *   <li>collect every servlet mapping, remembering which handler owns each path / extension and
   *       which handler is mapped to {@code "/"} (the default servlet);
   *   <li>for each known path, work out the target servlet and the filters that apply, both
   *       without an extension and for each known extension, and register the result with the
   *       builder as a prefix, extension or exact match;
   *   <li>register one name based match per servlet, used for name based dispatch.
   * </ol>
   *
   * <p>The original note on this method states that a chain consisting only of the default servlet
   * is added as an async handler so resources can be served without blocking operations; that
   * decision is not made in this method itself (presumably it lives in {@code createHandler} /
   * {@code servletChain} -- confirm there).
   *
   * <p>If two servlets share the same {@code "/"}, path or {@code *.ext} mapping, the exception
   * created by {@code UndertowServletMessages.MESSAGES.twoServletsWithSameMapping} is thrown.
   *
   * <p>TODO: this logic is a bit convoluted at the moment, we should look at simplifying it
   *
   * @return the match data produced by {@code ServletPathMatchesData.Builder#build()}
   */
  private ServletPathMatchesData setupServletChains() {
    // the handler mapped to "/", resolved below
    ServletHandler defaultServlet = null;
    final ManagedServlets servlets = deployment.getServlets();
    final ManagedFilters filters = deployment.getFilters();

    // servlet handlers keyed by extension (without the leading "*.") and by path pattern
    final Map<String, ServletHandler> extensionServlets = new HashMap<>();
    final Map<String, ServletHandler> pathServlets = new HashMap<>();

    // every path pattern / extension that is referenced by a servlet or a filter mapping
    final Set<String> pathMatches = new HashSet<>();
    final Set<String> extensionMatches = new HashSet<>();

    DeploymentInfo deploymentInfo = deployment.getDeploymentInfo();

    // loop through all filter mappings, and add them to the set of known paths
    for (FilterMappingInfo mapping : deploymentInfo.getFilterMappings()) {
      if (mapping.getMappingType() == FilterMappingInfo.MappingType.URL) {
        String path = mapping.getMapping();
        if (path.equals("*")) {
          // UNDERTOW-95, support this non-standard filter mapping
          path = "/*";
        }
        if (!path.startsWith("*.")) {
          pathMatches.add(path);
        } else {
          extensionMatches.add(path.substring(2));
        }
      }
    }

    // now loop through all servlets.
    for (Map.Entry<String, ServletHandler> entry : servlets.getServletHandlers().entrySet()) {
      final ServletHandler handler = entry.getValue();
      // add the servlet to the appropriate path maps
      for (String path : handler.getManagedServlet().getServletInfo().getMappings()) {
        if (path.equals("/")) {
          // the default servlet
          pathMatches.add("/*");
          if (defaultServlet != null) {
            throw UndertowServletMessages.MESSAGES.twoServletsWithSameMapping(path);
          }
          defaultServlet = handler;
        } else if (!path.startsWith("*.")) {
          // either an exact or a /* based path match
          if (path.isEmpty()) {
            path = "/";
          }
          pathMatches.add(path);
          if (pathServlets.containsKey(path)) {
            throw UndertowServletMessages.MESSAGES.twoServletsWithSameMapping(path);
          }
          pathServlets.put(path, handler);
        } else {
          // an extension match based servlet
          String ext = path.substring(2);
          extensionMatches.add(ext);
          // reject duplicates the same way path mappings do; silently overwriting would make the
          // winning servlet depend on map iteration order
          if (extensionServlets.containsKey(ext)) {
            throw UndertowServletMessages.MESSAGES.twoServletsWithSameMapping(path);
          }
          extensionServlets.put(ext, handler);
        }
      }
    }
    ServletHandler managedDefaultServlet = servlets.getServletHandler(DEFAULT_SERVLET_NAME);
    if (managedDefaultServlet == null) {
      // we always create a default servlet, even if it is not going to have any path mappings
      // registered
      managedDefaultServlet =
          servlets.addServlet(new ServletInfo(DEFAULT_SERVLET_NAME, DefaultServlet.class));
    }

    if (defaultServlet == null) {
      // no explicit default servlet was specified, so we register our mapping
      pathMatches.add("/*");
      defaultServlet = managedDefaultServlet;
    }

    final ServletPathMatchesData.Builder builder = ServletPathMatchesData.builder();

    // we now loop over every path in the application, and build up the matches based on this path
    // these paths contain both /* and exact matches.
    for (final String path : pathMatches) {
      // resolve the target servlet for this path
      // NOTE(review): the null check on targetServletMatch.handler further down suggests the
      // handler may be null, yet other code in this loop dereferences it unconditionally --
      // confirm against resolveServletForPath.
      MatchData targetServletMatch =
          resolveServletForPath(path, pathServlets, extensionServlets, defaultServlet);

      // filters that apply to this path when the request has no known extension
      final Map<DispatcherType, List<ManagedFilter>> noExtension =
          new EnumMap<>(DispatcherType.class);
      final Map<String, Map<DispatcherType, List<ManagedFilter>>> extension = new HashMap<>();
      // initialize the extension map. This contains all the filters in the noExtension map, plus
      // any filters that match the extension key
      for (String ext : extensionMatches) {
        extension.put(ext, new EnumMap<DispatcherType, List<ManagedFilter>>(DispatcherType.class));
      }

      // loop over all the filters, and add them to the appropriate map in the correct order
      for (final FilterMappingInfo filterMapping : deploymentInfo.getFilterMappings()) {
        ManagedFilter filter = filters.getManagedFilter(filterMapping.getFilterName());
        if (filterMapping.getMappingType() == FilterMappingInfo.MappingType.SERVLET) {
          // servlet-name based filter mapping: applies when it names the servlet serving the path
          if (targetServletMatch.handler != null) {
            if (filterMapping
                .getMapping()
                .equals(
                    targetServletMatch.handler.getManagedServlet().getServletInfo().getName())) {
              addToListMap(noExtension, filterMapping.getDispatcher(), filter);
            }
          }
          for (Map.Entry<String, Map<DispatcherType, List<ManagedFilter>>> entry :
              extension.entrySet()) {
            // for a default servlet match an extension mapped servlet takes over, so the filter
            // has to be compared against that servlet's name instead
            ServletHandler pathServlet = targetServletMatch.handler;
            boolean defaultServletMatch = targetServletMatch.defaultServlet;
            if (defaultServletMatch && extensionServlets.containsKey(entry.getKey())) {
              pathServlet = extensionServlets.get(entry.getKey());
            }

            if (filterMapping
                .getMapping()
                .equals(pathServlet.getManagedServlet().getServletInfo().getName())) {
              addToListMap(entry.getValue(), filterMapping.getDispatcher(), filter);
            }
          }
        } else {
          if (filterMapping.getMapping().isEmpty()
              || !filterMapping.getMapping().startsWith("*.")) {
            // path based filter mapping: applies to the plain path and to every extension
            if (isFilterApplicable(path, filterMapping.getMapping())) {
              addToListMap(noExtension, filterMapping.getDispatcher(), filter);
              for (Map<DispatcherType, List<ManagedFilter>> l : extension.values()) {
                addToListMap(l, filterMapping.getDispatcher(), filter);
              }
            }
          } else {
            // extension based filter mapping: only applies to its own extension
            addToListMap(
                extension.get(filterMapping.getMapping().substring(2)),
                filterMapping.getDispatcher(),
                filter);
          }
        }
      }
      // resolve any matches and add them to the builder
      if (path.endsWith("/*")) {
        String prefix = path.substring(0, path.length() - 2);
        // add the default non-extension match
        builder.addPrefixMatch(
            prefix,
            createHandler(
                deploymentInfo,
                targetServletMatch.handler,
                noExtension,
                targetServletMatch.matchedPath,
                targetServletMatch.defaultServlet),
            targetServletMatch.defaultServlet
                || targetServletMatch
                    .handler
                    .getManagedServlet()
                    .getServletInfo()
                    .isRequireWelcomeFileMapping());

        // build up the chain for each extension match
        for (Map.Entry<String, Map<DispatcherType, List<ManagedFilter>>> entry :
            extension.entrySet()) {
          ServletHandler pathServlet = targetServletMatch.handler;
          String pathMatch = targetServletMatch.matchedPath;

          boolean defaultServletMatch = targetServletMatch.defaultServlet;
          if (defaultServletMatch && extensionServlets.containsKey(entry.getKey())) {
            // an extension mapped servlet overrides the default servlet for this extension
            defaultServletMatch = false;
            pathServlet = extensionServlets.get(entry.getKey());
          }
          HttpHandler handler = pathServlet;
          if (!entry.getValue().isEmpty()) {
            handler =
                new FilterHandler(
                    entry.getValue(), deploymentInfo.isAllowNonStandardWrappers(), handler);
          }
          builder.addExtensionMatch(
              prefix,
              entry.getKey(),
              servletChain(
                  handler,
                  pathServlet.getManagedServlet(),
                  pathMatch,
                  deploymentInfo,
                  defaultServletMatch));
        }
      } else if (path.isEmpty()) {
        // the context root match
        builder.addExactMatch(
            "/",
            createHandler(
                deploymentInfo,
                targetServletMatch.handler,
                noExtension,
                targetServletMatch.matchedPath,
                targetServletMatch.defaultServlet));
      } else {
        // we need to check for an extension match, so paths like /exact.txt will have the correct
        // filter applied
        final int lastSlash = path.lastIndexOf('/');
        // a pattern without any '/' is treated as a single segment instead of failing in
        // substring(-1)
        final String lastSegment = lastSlash >= 0 ? path.substring(lastSlash) : path;

        Map<DispatcherType, List<ManagedFilter>> exactFilters = noExtension;
        final int lastDot = lastSegment.lastIndexOf('.');
        if (lastDot >= 0) {
          // every value in the extension map is non-null, so a null lookup means "unknown
          // extension" and the plain path filters are kept
          final Map<DispatcherType, List<ManagedFilter>> extMap =
              extension.get(lastSegment.substring(lastDot + 1));
          if (extMap != null) {
            exactFilters = extMap;
          }
        }
        builder.addExactMatch(
            path,
            createHandler(
                deploymentInfo,
                targetServletMatch.handler,
                exactFilters,
                targetServletMatch.matchedPath,
                targetServletMatch.defaultServlet));
      }
    }

    // now setup name based mappings
    // these are used for name based dispatch
    for (Map.Entry<String, ServletHandler> entry : servlets.getServletHandlers().entrySet()) {
      final Map<DispatcherType, List<ManagedFilter>> filtersByDispatcher =
          new EnumMap<>(DispatcherType.class);
      // only servlet-name filter mappings that name this servlet take part in name dispatch
      for (final FilterMappingInfo filterMapping : deploymentInfo.getFilterMappings()) {
        ManagedFilter filter = filters.getManagedFilter(filterMapping.getFilterName());
        if (filterMapping.getMappingType() == FilterMappingInfo.MappingType.SERVLET) {
          if (filterMapping.getMapping().equals(entry.getKey())) {
            addToListMap(filtersByDispatcher, filterMapping.getDispatcher(), filter);
          }
        }
      }
      if (filtersByDispatcher.isEmpty()) {
        builder.addNameMatch(
            entry.getKey(),
            servletChain(
                entry.getValue(),
                entry.getValue().getManagedServlet(),
                null,
                deploymentInfo,
                false));
      } else {
        builder.addNameMatch(
            entry.getKey(),
            servletChain(
                new FilterHandler(
                    filtersByDispatcher,
                    deploymentInfo.isAllowNonStandardWrappers(),
                    entry.getValue()),
                entry.getValue().getManagedServlet(),
                null,
                deploymentInfo,
                false));
      }
    }

    return builder.build();
  }
  /**
   * Creates the managed filters and servlets for the deployment and builds the path, extension
   * and name based lookup tables for them. A chain is set up for every path + extension
   * combination (i.e. with m path mappings and n extension mappings there are n*m chains).
   *
   * <p>Every {@code ManagedFilter} and {@code ManagedServlet} created here is collected and
   * handed to {@code deployment.addLifecycleObjects} before the method returns.
   *
   * <p>If no servlet is mapped to {@code "/"}, a {@code DefaultServlet} instance named {@code
   * "io.undertow.DefaultServlet"} is created and used as the default servlet.
   *
   * <p>If two servlets share the same {@code "/"}, path or {@code *.ext} mapping, or a servlet is
   * mapped to {@code "/"} after another one was registered for {@code "/*"}, the exception created
   * by {@code UndertowServletMessages.MESSAGES.twoServletsWithSameMapping} is thrown.
   *
   * <p>TODO: this logic is a bit convoluted at the moment, we should look at simplifying it
   *
   * @param servletContext the context passed to every {@code ManagedFilter} and {@code
   *     ManagedServlet} created by this method
   * @param threadSetupAction not used by this method
   * @param listeners not used by this method
   * @return the matches produced by {@code ServletPathMatches.Builder#build()}
   */
  private ServletPathMatches setupServletChains(
      final ServletContextImpl servletContext,
      final CompositeThreadSetupAction threadSetupAction,
      final ApplicationListeners listeners) {
    // everything created here that is registered with the deployment at the end
    final List<Lifecycle> lifecycles = new ArrayList<>();
    // the chain / handler for the servlet mapped to "/", resolved below
    ServletChain defaultHandler = null;
    ServletHandler defaultServlet = null;

    final Map<String, ManagedFilter> managedFilterMap = new LinkedHashMap<>();
    // servlet handlers keyed by servlet name, by extension (without "*.") and by path pattern
    final Map<String, ServletHandler> allServlets = new HashMap<>();
    final Map<String, ServletHandler> extensionServlets = new HashMap<>();
    final Map<String, ServletHandler> pathServlets = new HashMap<>();

    // every path pattern / extension that is referenced by a servlet or a filter mapping
    final Set<String> pathMatches = new HashSet<>();
    final Set<String> extensionMatches = new HashSet<>();

    DeploymentInfo deploymentInfo = deployment.getDeploymentInfo();
    // create one managed filter per declared filter, keyed by filter name
    for (Map.Entry<String, FilterInfo> entry : deploymentInfo.getFilters().entrySet()) {
      final ManagedFilter mf = new ManagedFilter(entry.getValue(), servletContext);
      managedFilterMap.put(entry.getValue().getName(), mf);
      lifecycles.add(mf);
    }

    // record every URL pattern used by a filter mapping
    for (FilterMappingInfo mapping : deploymentInfo.getFilterMappings()) {
      if (mapping.getMappingType() == FilterMappingInfo.MappingType.URL) {
        String path = mapping.getMapping();
        if (!path.startsWith("*.")) {
          pathMatches.add(path);
        } else {
          extensionMatches.add(path.substring(2));
        }
      }
    }

    // create a managed servlet + handler per declared servlet and record its mappings
    for (Map.Entry<String, ServletInfo> entry : deploymentInfo.getServlets().entrySet()) {
      ServletInfo servlet = entry.getValue();
      final ManagedServlet managedServlet = new ManagedServlet(servlet, servletContext);
      lifecycles.add(managedServlet);
      final ServletHandler handler = new ServletHandler(managedServlet);
      allServlets.put(entry.getKey(), handler);
      for (String path : entry.getValue().getMappings()) {
        if (path.equals("/")) {
          // the default servlet
          pathMatches.add("/*");
          // a second "/" mapping must not silently replace the first one; the "/*" check is
          // kept from the original behaviour
          if (defaultServlet != null || pathServlets.containsKey("/*")) {
            throw UndertowServletMessages.MESSAGES.twoServletsWithSameMapping(path);
          }
          defaultServlet = handler;
          defaultHandler = servletChain(handler, managedServlet);
        } else if (!path.startsWith("*.")) {
          // either an exact or a /* based path match
          pathMatches.add(path);
          if (pathServlets.containsKey(path)) {
            throw UndertowServletMessages.MESSAGES.twoServletsWithSameMapping(path);
          }
          pathServlets.put(path, handler);
        } else {
          // an extension match based servlet
          String ext = path.substring(2);
          extensionMatches.add(ext);
          // reject duplicates the same way path mappings do; silently overwriting would make the
          // winning servlet depend on map iteration order
          if (extensionServlets.containsKey(ext)) {
            throw UndertowServletMessages.MESSAGES.twoServletsWithSameMapping(path);
          }
          extensionServlets.put(ext, handler);
        }
      }
    }

    if (defaultServlet == null) {
      // no servlet was mapped to "/", so create the built-in default servlet
      final DefaultServletConfig config =
          deploymentInfo.getDefaultServletConfig() == null
              ? new DefaultServletConfig()
              : deploymentInfo.getDefaultServletConfig();
      DefaultServlet defaultInstance =
          new DefaultServlet(deployment, config, deploymentInfo.getWelcomePages());
      final ManagedServlet managedDefaultServlet =
          new ManagedServlet(
              new ServletInfo(
                  "io.undertow.DefaultServlet",
                  DefaultServlet.class,
                  new ImmediateInstanceFactory<Servlet>(defaultInstance)),
              servletContext);
      lifecycles.add(managedDefaultServlet);
      pathMatches.add("/*");
      defaultServlet = new ServletHandler(managedDefaultServlet);
      defaultHandler = new ServletChain(defaultServlet, managedDefaultServlet);
    }

    final ServletPathMatches.Builder builder = ServletPathMatches.builder();

    // loop over every known path and build the chains for it
    for (final String path : pathMatches) {
      // the code below treats a null result as "fall back to the default servlet"
      ServletHandler targetServlet = resolveServletForPath(path, pathServlets);

      // filters that apply to this path when the request has no known extension
      final Map<DispatcherType, List<ManagedFilter>> noExtension = new HashMap<>();
      // filters that apply to this path for each known extension
      final Map<String, Map<DispatcherType, List<ManagedFilter>>> extension = new HashMap<>();
      for (String ext : extensionMatches) {
        extension.put(ext, new HashMap<DispatcherType, List<ManagedFilter>>());
      }

      // loop over all the filter mappings and add the filter to the maps it applies to
      for (final FilterMappingInfo filterMapping : deploymentInfo.getFilterMappings()) {
        ManagedFilter filter = managedFilterMap.get(filterMapping.getFilterName());
        if (filterMapping.getMappingType() == FilterMappingInfo.MappingType.SERVLET) {
          // servlet-name based filter mapping: applies when it names the servlet for this path
          if (targetServlet != null) {
            if (filterMapping
                .getMapping()
                .equals(targetServlet.getManagedServlet().getServletInfo().getName())) {
              addToListMap(noExtension, filterMapping.getDispatcher(), filter);
              for (Map<DispatcherType, List<ManagedFilter>> l : extension.values()) {
                addToListMap(l, filterMapping.getDispatcher(), filter);
              }
            }
          }
        } else {
          if (filterMapping.getMapping().isEmpty()
              || !filterMapping.getMapping().startsWith("*.")) {
            // path based filter mapping: applies to the plain path and to every extension
            if (isFilterApplicable(path, filterMapping.getMapping())) {
              addToListMap(noExtension, filterMapping.getDispatcher(), filter);
              for (Map<DispatcherType, List<ManagedFilter>> l : extension.values()) {
                addToListMap(l, filterMapping.getDispatcher(), filter);
              }
            }
          } else {
            // extension based filter mapping: only applies to its own extension
            addToListMap(
                extension.get(filterMapping.getMapping().substring(2)),
                filterMapping.getDispatcher(),
                filter);
          }
        }
      }

      // the chain used for this path when no extension specific chain applies
      final ServletChain initialHandler;
      if (noExtension.isEmpty()) {
        if (targetServlet != null) {
          initialHandler = servletChain(targetServlet, targetServlet.getManagedServlet());
        } else {
          initialHandler = defaultHandler;
        }
      } else {
        FilterHandler handler;
        if (targetServlet != null) {
          handler = new FilterHandler(noExtension, targetServlet);
        } else {
          handler = new FilterHandler(noExtension, defaultServlet);
        }
        initialHandler =
            servletChain(
                handler,
                targetServlet == null
                    ? defaultServlet.getManagedServlet()
                    : targetServlet.getManagedServlet());
      }

      if (path.endsWith("/*")) {
        String prefix = path.substring(0, path.length() - 2);
        builder.addPrefixMatch(prefix, initialHandler);

        // build up the chain for each extension under this prefix
        for (Map.Entry<String, Map<DispatcherType, List<ManagedFilter>>> entry :
            extension.entrySet()) {
          // precedence: servlet resolved for the path, then extension servlet, then default
          ServletHandler pathServlet = targetServlet;
          if (pathServlet == null) {
            pathServlet = extensionServlets.get(entry.getKey());
          }
          if (pathServlet == null) {
            pathServlet = defaultServlet;
          }
          HttpHandler handler = pathServlet;
          if (!entry.getValue().isEmpty()) {
            handler = new FilterHandler(entry.getValue(), handler);
          }
          builder.addExtensionMatch(
              prefix, entry.getKey(), servletChain(handler, pathServlet.getManagedServlet()));
        }
      } else if (path.isEmpty()) {
        // the context root match
        builder.addExactMatch("/", initialHandler);
      } else {
        builder.addExactMatch(path, initialHandler);
      }
    }

    // now setup name based mappings
    // these are used for name based dispatch
    for (Map.Entry<String, ServletHandler> entry : allServlets.entrySet()) {
      final Map<DispatcherType, List<ManagedFilter>> filters = new HashMap<>();
      // only servlet-name filter mappings that name this servlet take part in name dispatch
      for (final FilterMappingInfo filterMapping : deploymentInfo.getFilterMappings()) {
        ManagedFilter filter = managedFilterMap.get(filterMapping.getFilterName());
        if (filterMapping.getMappingType() == FilterMappingInfo.MappingType.SERVLET) {
          if (filterMapping.getMapping().equals(entry.getKey())) {
            addToListMap(filters, filterMapping.getDispatcher(), filter);
          }
        }
      }
      if (filters.isEmpty()) {
        builder.addNameMatch(
            entry.getKey(), servletChain(entry.getValue(), entry.getValue().getManagedServlet()));
      } else {
        builder.addNameMatch(
            entry.getKey(),
            servletChain(
                new FilterHandler(filters, entry.getValue()),
                entry.getValue().getManagedServlet()));
      }
    }

    builder.setDefaultServlet(defaultHandler);

    deployment.addLifecycleObjects(lifecycles);
    return builder.build();
  }